Commit cf66c525 authored by qianyj's avatar qianyj
Browse files

update some TF file

parent 6b6f8b0c
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tf2_utils_2x_wide."""
import numpy as np
import tensorflow as tf
from official.modeling.fast_training.experimental import tf2_utils_2x_wide
class Tf2Utils2XWideTest(tf.test.TestCase):
  """Checks the 2x-widening helpers against hand-computed expectations."""

  def test_expand_vector(self):
    """expand_vector turns [1, 2] into [1, 1, 2, 2]."""
    narrow = np.array([1, 2])
    widened = tf2_utils_2x_wide.expand_vector(narrow)
    self.assertAllClose(widened, np.array([1, 1, 2, 2]))

  def test_expand_matrix(self):
    """Adjacent row pairs of expand_2_axes sum to column-duplicated rows."""
    narrow = np.array([[1, 2], [3, 4]])
    widened = tf2_utils_2x_wide.expand_2_axes(narrow, epsilon=0.1)
    first_pair = widened[0, :] + widened[1, :]
    self.assertAllClose(first_pair, np.array([1, 1, 2, 2]))
    second_pair = widened[2, :] + widened[3, :]
    self.assertAllClose(second_pair, np.array([3, 3, 4, 4]))

  def test_expand_matrix_axis_0(self):
    """Widening only axis 0: adjacent row pairs sum to the original rows."""
    narrow = np.array([[1, 2], [3, 4]])
    widened = tf2_utils_2x_wide.expand_1_axis(narrow, axis=0, epsilon=0.1)
    first_pair = widened[0, :] + widened[1, :]
    self.assertAllClose(first_pair, np.array([1, 2]))
    second_pair = widened[2, :] + widened[3, :]
    self.assertAllClose(second_pair, np.array([3, 4]))

  def test_expand_matrix_axis_1(self):
    """Widening only the last axis: adjacent column pairs sum to originals."""
    narrow = np.array([[1, 2], [3, 4]])
    widened = tf2_utils_2x_wide.expand_1_axis(narrow, axis=-1, epsilon=0.1)
    first_pair = widened[:, 0] + widened[:, 1]
    self.assertAllClose(first_pair, np.array([1, 3]))
    second_pair = widened[:, 2] + widened[:, 3]
    self.assertAllClose(second_pair, np.array([2, 4]))

  def test_expand_3d_tensor(self):
    """Duplicated input times a 2-axes-widened kernel duplicates the output."""
    narrow_input = np.array([10, 11])
    wide_input = np.array([10, 10, 11, 11])
    narrow_kernel = np.random.rand(2, 2)
    wide_kernel = tf2_utils_2x_wide.expand_2_axes(narrow_kernel, epsilon=0.1)
    narrow_output = np.matmul(narrow_input, narrow_kernel)
    wide_output = np.matmul(wide_input, wide_kernel)
    self.assertAllClose(np.repeat(narrow_output, 2, axis=-1), wide_output)

  def test_expand_3d_tensor_axis_0(self):
    """Duplicated input times an axis-0-widened kernel keeps the output."""
    narrow_input = np.array([10, 11])
    wide_input = np.array([10, 10, 11, 11])
    narrow_kernel = np.random.rand(2, 2)
    wide_kernel = tf2_utils_2x_wide.expand_1_axis(
        narrow_kernel, axis=0, epsilon=0.1)
    narrow_output = np.matmul(narrow_input, narrow_kernel)
    wide_output = np.matmul(wide_input, wide_kernel)
    self.assertAllClose(narrow_output, wide_output)

  def test_expand_3d_tensor_axis_2(self):
    """Pairs of outputs from a last-axis-widened kernel sum to the original."""
    inputs = np.array([10, 11])
    narrow_kernel = np.random.rand(2, 2)
    wide_kernel = tf2_utils_2x_wide.expand_1_axis(
        narrow_kernel, axis=-1, epsilon=0.1)
    narrow_output = np.matmul(inputs, narrow_kernel)
    wide_output = np.matmul(inputs, wide_kernel)
    pair_sums = np.sum(wide_output.reshape(2, 2), axis=-1)
    self.assertAllClose(narrow_output, pair_sums)

  def test_end_to_end(self):
    """Covers expand_vector, expand_2_axes, and expand_1_axis."""

    def build_mlp(input_dim, hidden_dim):
      # Three stacked Dense layers; only the final layer's width is fixed at 1.
      model = tf.keras.Sequential()
      model.add(tf.keras.Input(shape=(input_dim,)))
      for units in (hidden_dim, hidden_dim, 1):
        model.add(tf.keras.layers.Dense(units))
      return model

    model_narrow = build_mlp(input_dim=3, hidden_dim=4)
    model_wide = build_mlp(input_dim=6, hidden_dim=8)

    narrow_input = np.array([[1, 2, 3]])
    wide_input = np.array([[1, 1, 2, 2, 3, 3]])

    # Call model once to build variables first.
    _, _ = model_narrow(narrow_input), model_wide(wide_input)
    tf2_utils_2x_wide.model_to_model_2x_wide(
        model_narrow, model_wide, epsilon=0.2)
    self.assertAllClose(
        model_narrow(narrow_input),
        model_wide(wide_input),
        rtol=1e-05,
        atol=1e-05)


if __name__ == "__main__":
  tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base ProgressivePolicy definition for progressive training.
To write a progressive model, subclass ProgressivePolicy and implement its
abstract methods to handle each training stage.
"""
import abc
import dataclasses
from typing import Any, Mapping
from absl import logging
import six
import tensorflow as tf
from official.common import streamz_counters
from official.modeling.fast_training.progressive import utils
from official.modeling.hyperparams import base_config
@dataclasses.dataclass
class ProgressiveConfig(base_config.Config):
  """Base config for progressive training; declares no fields of its own."""
  pass
@six.add_metaclass(abc.ABCMeta)
class ProgressivePolicy:
  """The APIs for handling progressive training stages.

  Subclasses provide the per-stage model, optimizer and datasets through the
  abstract methods. This base class keeps the current stage id in a
  `tf.Variable` and swaps the current model/optimizer when the stage changes.

  Attributes:
    cur_model: The model for the current progressive training stage.
    cur_train_dataset: The train dataset function for the current stage.
    cur_eval_dataset: The eval dataset function for the current stage.
    cur_optimizer: The optimizer for the current stage.
    cur_checkpoint_items: Items to be saved in and restored from checkpoints,
      for the progressive trainer.
    is_last_stage: Whether it is currently in the last stage.

  Interfaces:
    is_stage_advancing: Returns if progressive training is advancing to the
      next stage.
    update_pt_stage: Update progressive training stage.
  """

  def __init__(self):
    """Initialize stage policy.

    Starts at stage 0 and immediately builds that stage's optimizer and model
    through `get_optimizer` and `get_model`, so subclasses must be ready to
    serve those calls before invoking this constructor.
    """
    # Datasets are created lazily by the `cur_*_dataset` properties.
    self._cur_train_dataset = None
    self._cur_eval_dataset = None
    self._volatiles = utils.VolatileTrackable(optimizer=None, model=None)

    stage_id = 0
    self._stage_id = tf.Variable(
        stage_id,
        trainable=False,
        dtype=tf.int64,
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
        shape=[])
    self._volatiles.reassign_trackable(
        optimizer=self.get_optimizer(stage_id),
        model=self.get_model(stage_id, old_model=None))  # pytype: disable=wrong-arg-types  # typed-keras

    streamz_counters.progressive_policy_creation_counter.get_cell(
    ).increase_by(1)

  def compute_stage_id(self, global_step: int) -> int:
    """Return the id of the stage that `global_step` falls into.

    Stage lengths from `num_steps` are subtracted one stage at a time; the
    first stage that drives the remainder below zero contains the step.

    Args:
      global_step: The global training step to locate.

    Returns:
      The matching stage id, or the last stage id (after logging an error) when
      `global_step` lies beyond the total number of steps of all stages.
    """
    # Subtract from a separate counter so the error below can report the
    # step the caller passed in rather than the leftover remainder.
    remaining_steps = global_step
    for stage_id in range(self.num_stages()):
      remaining_steps -= self.num_steps(stage_id)
      if remaining_steps < 0:
        return stage_id
    logging.error('Global step %d found no matching progressive stages. '
                  'Default to the last stage.', global_step)
    return self.num_stages() - 1

  @abc.abstractmethod
  def num_stages(self) -> int:
    """Return the total number of progressive stages."""
    pass

  @abc.abstractmethod
  def num_steps(self, stage_id: int) -> int:
    """Return the total number of steps in this stage."""
    pass

  @abc.abstractmethod
  def get_model(self,
                stage_id: int,
                old_model: tf.keras.Model = None) -> tf.keras.Model:  # pytype: disable=annotation-type-mismatch  # typed-keras
    """Return model for this stage. For initialization, `old_model` = None."""
    pass

  @abc.abstractmethod
  def get_optimizer(self, stage_id: int) -> tf.keras.optimizers.Optimizer:
    """Return optimizer for this stage."""
    pass

  @abc.abstractmethod
  def get_train_dataset(self, stage_id: int) -> tf.data.Dataset:
    """Return training Dataset for this stage."""
    pass

  @abc.abstractmethod
  def get_eval_dataset(self, stage_id: int) -> tf.data.Dataset:
    """Return evaluation Dataset for this stage."""
    pass

  @property
  def cur_model(self) -> tf.keras.Model:
    """The model currently held by the volatile trackable."""
    return self._volatiles.model

  @property
  def cur_train_dataset(self) -> tf.data.Dataset:
    """Train dataset of the current stage, built on first access and cached."""
    if self._cur_train_dataset is None:
      self._cur_train_dataset = self.get_train_dataset(self._stage_id.numpy())
    return self._cur_train_dataset

  @property
  def cur_eval_dataset(self) -> tf.data.Dataset:
    """Eval dataset of the current stage, built on first access and cached."""
    if self._cur_eval_dataset is None:
      self._cur_eval_dataset = self.get_eval_dataset(self._stage_id.numpy())
    return self._cur_eval_dataset

  @property
  def cur_optimizer(self) -> tf.keras.optimizers.Optimizer:
    """The optimizer currently held by the volatile trackable."""
    return self._volatiles.optimizer

  @property
  def is_last_stage(self) -> bool:
    """Whether the stored stage id is the final stage (or beyond it)."""
    stage_id = self._stage_id.numpy()
    return stage_id >= self.num_stages() - 1

  @property
  def cur_checkpoint_items(self) -> Mapping[str, Any]:
    """The stage id variable and the volatiles wrapper, keyed by name."""
    return dict(stage_id=self._stage_id, volatiles=self._volatiles)

  def is_stage_advancing(self, global_step: int) -> bool:
    """Return whether `global_step` maps to a stage other than the stored one.

    Args:
      global_step: an integer scalar of the current global step.
    """
    old_stage_id = self._stage_id.numpy()
    new_stage_id = self.compute_stage_id(global_step)
    return old_stage_id != new_stage_id

  def update_pt_stage(self, global_step: int, pass_old_model=True) -> None:
    """Update progressive training internal status.

    Call this after a training loop ends.

    Args:
      global_step: an integer scalar of the current global step.
      pass_old_model: whether to pass the old_model to get_model() function.
        This is set to False if the old_model is irrelevant (e.g, just a default
        model from stage 0).
    """
    old_stage_id = self._stage_id.numpy()
    new_stage_id = self.compute_stage_id(global_step)
    logging.info('Switching stage from %d to %d', old_stage_id, new_stage_id)

    # Update stage id.
    self._stage_id.assign(new_stage_id)
    # Update dataset function: clearing the caches makes the `cur_*_dataset`
    # properties rebuild the datasets for the new stage on next access.
    self._cur_train_dataset = None
    self._cur_eval_dataset = None

    # Update optimizer and model. The model is replaced last so that
    # `self.cur_model` still refers to the previous stage's model when it is
    # handed to `get_model` as `old_model`.
    new_optimizer = self.get_optimizer(new_stage_id)
    self._volatiles.reassign_trackable(optimizer=new_optimizer)
    new_model = self.get_model(
        new_stage_id, old_model=self.cur_model if pass_old_model else None)
    self._volatiles.reassign_trackable(model=new_model)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TFM binary for the progressive trainer."""
from absl import app
from absl import flags
import gin
from official.common import distribute_utils
# pylint: disable=unused-import
from official.common import registry_imports
# pylint: enable=unused-import
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_utils
from official.modeling import performance
from official.modeling.fast_training.progressive import train_lib
# Module-level handle to the absl flag values that main() reads.
FLAGS = flags.FLAGS
def main(_):
  """Configures the run from flags and gin, then launches the experiment.

  Args:
    _: Unused positional argument supplied by `app.run`.
  """
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  experiment_params = train_utils.parse_configuration(FLAGS)
  output_dir = FLAGS.model_dir

  # Pure eval modes do not output yaml files. Otherwise continuous eval job
  # may race against the train job for writing the same file.
  if 'train' in FLAGS.mode:
    train_utils.serialize_config(experiment_params, output_dir)

  runtime = experiment_params.runtime

  # Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
  # can have significant impact on model speeds by utilizing float16 in case of
  # GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
  # dtype is float16
  precision_dtype = runtime.mixed_precision_dtype
  if precision_dtype:
    performance.set_mixed_precision_policy(precision_dtype)

  strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=runtime.distribution_strategy,
      all_reduce_alg=runtime.all_reduce_alg,
      num_gpus=runtime.num_gpus,
      tpu_address=runtime.tpu,
      **runtime.model_parallelism())

  # The task is created under the strategy scope.
  with strategy.scope():
    task = task_factory.get_task(experiment_params.task, logging_dir=output_dir)

  train_lib.run_experiment(
      distribution_strategy=strategy,
      task=task,
      mode=FLAGS.mode,
      params=experiment_params,
      model_dir=output_dir)

  train_utils.save_gin_config(FLAGS.mode, output_dir)


if __name__ == '__main__':
  tfm_flags.define_flags()
  app.run(main)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TFM progressive training driver library.
Compared to the common training driver, the only difference is that we use
prog_trainer_lib.ProgressiveTrainer instead of the base trainer.
"""
# pytype: disable=attribute-error
import os
from typing import Any, Mapping, Tuple
# Import libraries
from absl import logging
import orbit
import tensorflow as tf
from official.core import base_task
from official.core import config_definitions
from official.core import train_lib as base_train_lib
from official.modeling.fast_training.progressive import trainer as prog_trainer_lib
def run_experiment(
    distribution_strategy: tf.distribute.Strategy,
    task: base_task.Task,
    mode: str,
    params: config_definitions.ExperimentConfig,
    model_dir: str,
    run_post_eval: bool = False,
    save_summary: bool = True
) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
  """Runs progressive train/eval as configured by the experiment params.

  Args:
    distribution_strategy: The distribution strategy whose scope is used to
      build the trainer and to run the selected mode.
    task: A Task instance.
    mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
      or 'continuous_eval'.
    params: ExperimentConfig instance.
    model_dir: A 'str', a path to store model checkpoints and summaries.
    run_post_eval: Whether to run post eval once after training, metrics logs
      are returned.
    save_summary: Whether to save train and validation summary.

  Returns:
    A 2-tuple of (model, eval_logs).
      model: `tf.keras.Model` instance.
      eval_logs: returns eval metrics logs when run_post_eval is set to True,
        otherwise, returns {}.

  Raises:
    NotImplementedError: If `mode` is not one of the supported modes.
  """
  trainer_params = params.trainer
  is_training = 'train' in mode

  with distribution_strategy.scope():
    logging.info('Running progressive trainer.')
    trainer = prog_trainer_lib.ProgressiveTrainer(
        params,
        task,
        ckpt_dir=model_dir,
        train=is_training,
        evaluate=('eval' in mode) or run_post_eval,
        checkpoint_exporter=base_train_lib.maybe_create_best_ckpt_exporter(
            params, model_dir))

  # A checkpoint manager is only created when the trainer has a checkpoint.
  checkpoint_manager = None
  if trainer.checkpoint:
    checkpoint_manager = tf.train.CheckpointManager(
        trainer.checkpoint,
        directory=model_dir,
        max_to_keep=trainer_params.max_to_keep,
        step_counter=trainer.global_step,
        checkpoint_interval=trainer_params.checkpoint_interval,
        init_fn=trainer.initialize)

  # All three summary settings are None when summaries are disabled.
  if save_summary:
    summary_dir = os.path.join(model_dir, 'train')
    eval_summary_dir = os.path.join(model_dir, 'validation')
    summary_interval = trainer_params.summary_interval
  else:
    summary_dir = None
    eval_summary_dir = None
    summary_interval = None

  controller = orbit.Controller(
      strategy=distribution_strategy,
      trainer=trainer if is_training else None,
      evaluator=trainer,
      global_step=trainer.global_step,
      steps_per_loop=trainer_params.steps_per_loop,
      checkpoint_manager=checkpoint_manager,
      summary_dir=summary_dir,
      eval_summary_dir=eval_summary_dir,
      summary_interval=summary_interval)

  logging.info('Starts to execute mode: %s', mode)
  with distribution_strategy.scope():
    if mode == 'train':
      controller.train(steps=trainer_params.train_steps)
    elif mode == 'train_and_eval':
      controller.train_and_evaluate(
          train_steps=trainer_params.train_steps,
          eval_steps=trainer_params.validation_steps,
          eval_interval=trainer_params.validation_interval)
    elif mode == 'eval':
      controller.evaluate(steps=trainer_params.validation_steps)
    elif mode == 'continuous_eval':

      def timeout_fn():
        # True once the global step has reached the configured train steps.
        return bool(
            trainer.global_step.numpy() >= params.trainer.train_steps)

      controller.evaluate_continuously(
          steps=trainer_params.validation_steps,
          timeout=trainer_params.continuous_eval_timeout,
          timeout_fn=timeout_fn)
    else:
      raise NotImplementedError('The mode is not implemented: %s' % mode)

  if not run_post_eval:
    return trainer.model, {}

  with distribution_strategy.scope():
    return trainer.model, trainer.evaluate(
        tf.convert_to_tensor(params.trainer.validation_steps))
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the progressive train_lib."""
import os
from absl import flags
from absl.testing import parameterized
import dataclasses
import orbit
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.common import flags as tfm_flags
# pylint: disable=unused-import
from official.common import registry_imports
# pylint: enable=unused-import
from official.core import config_definitions as cfg
from official.core import task_factory
from official.modeling import optimization
from official.modeling.hyperparams import params_dict
from official.modeling.fast_training.progressive import policies
from official.modeling.fast_training.progressive import train_lib
from official.modeling.fast_training.progressive import trainer as prog_trainer_lib
from official.utils.testing import mock_task
FLAGS = flags.FLAGS
# Define the TFM flags once, at import time.
tfm_flags.define_flags()
@dataclasses.dataclass
class ProgTaskConfig(cfg.TaskConfig):
  """Task config for the progressive mock task; adds no fields of its own."""
  pass
@task_factory.register_task_cls(ProgTaskConfig)
class ProgMockTask(policies.ProgressivePolicy, mock_task.MockTask):
  """Progressive task for testing."""

  def __init__(self, params: cfg.TaskConfig, logging_dir: str = None):
    """Initialises the MockTask base first, then the progressive policy."""
    mock_task.MockTask.__init__(self, params=params, logging_dir=logging_dir)
    policies.ProgressivePolicy.__init__(self)

  def num_stages(self):
    """Two stages in total."""
    return 2

  def num_steps(self, stage_id):
    """Stage 0 lasts 2 steps; any other stage lasts 4."""
    if stage_id == 0:
      return 2
    return 4

  def get_model(self, stage_id, old_model=None):
    """Builds a fresh model; both arguments are ignored."""
    del stage_id, old_model
    return self.build_model()

  def get_optimizer(self, stage_id):
    """Build optimizer for each stage."""
    learning_rate_config = {
        'type': 'polynomial',
        'polynomial': {
            'initial_learning_rate': 0.01,
            'end_learning_rate': 0.0,
            'power': 1.0,
            'decay_steps': 10,
        },
    }
    warmup_config = {
        'polynomial': {
            'power': 1,
            'warmup_steps': 2,
        },
        'type': 'polynomial',
    }
    optimization_config = optimization.OptimizationConfig({
        'optimizer': {
            'type': 'adamw',
        },
        'learning_rate': learning_rate_config,
        'warmup': warmup_config,
    })
    factory = optimization.OptimizerFactory(optimization_config)
    learning_rate = factory.build_learning_rate()
    return factory.build_optimizer(learning_rate)

  def get_train_dataset(self, stage_id):
    """Distributed dataset from `build_inputs`; the stage id is ignored."""
    del stage_id
    return orbit.utils.make_distributed_dataset(
        tf.distribute.get_strategy(), self.build_inputs, None)

  def get_eval_dataset(self, stage_id):
    """Distributed dataset from `build_inputs`; the stage id is ignored."""
    del stage_id
    return orbit.utils.make_distributed_dataset(
        tf.distribute.get_strategy(), self.build_inputs, None)
class TrainTest(tf.test.TestCase, parameterized.TestCase):
  """End-to-end tests for the progressive `train_lib.run_experiment`."""

  def setUp(self):
    """Prepares the trainer overrides applied to every experiment config."""
    super().setUp()
    optimizer_overrides = {
        'optimizer': {
            'type': 'sgd',
        },
        'learning_rate': {
            'type': 'constant'
        }
    }
    trainer_overrides = {
        'checkpoint_interval': 10,
        'steps_per_loop': 10,
        'summary_interval': 10,
        'train_steps': 10,
        'validation_steps': 5,
        'validation_interval': 10,
        'continuous_eval_timeout': 1,
        'optimizer_config': optimizer_overrides,
    }
    self._test_config = {'trainer': trainer_overrides}

  @combinations.generate(
      combinations.combine(
          distribution_strategy=[
              strategy_combinations.default_strategy,
              strategy_combinations.cloud_tpu_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ],
          flag_mode=['train', 'eval', 'train_and_eval'],
          run_post_eval=[True, False]))
  def test_end_to_end(self, distribution_strategy, flag_mode, run_post_eval):
    """Runs the requested mode, then continuous eval for training modes."""
    model_dir = self.get_temp_dir()
    base_config = cfg.ExperimentConfig(
        trainer=prog_trainer_lib.ProgressiveTrainerConfig(),
        task=ProgTaskConfig())
    experiment_config = params_dict.override_params_dict(
        base_config, self._test_config, is_strict=False)

    with distribution_strategy.scope():
      task = task_factory.get_task(
          experiment_config.task, logging_dir=model_dir)

    # Arguments shared by both run_experiment calls below.
    shared_kwargs = dict(
        distribution_strategy=distribution_strategy,
        task=task,
        params=experiment_config,
        model_dir=model_dir,
        run_post_eval=run_post_eval)

    _, logs = train_lib.run_experiment(mode=flag_mode, **shared_kwargs)

    if run_post_eval:
      self.assertNotEmpty(logs)
    else:
      self.assertEmpty(logs)

    # The checkpoint check and continuous eval are skipped for pure eval.
    if flag_mode == 'eval':
      return

    checkpoint_index = os.path.join(model_dir, 'checkpoint')
    self.assertNotEmpty(tf.io.gfile.glob(checkpoint_index))

    # Tests continuous evaluation.
    _, logs = train_lib.run_experiment(mode='continuous_eval', **shared_kwargs)
    print(logs)


if __name__ == '__main__':
  tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Progressive Trainer implementation.
The trainer implements the Orbit `StandardTrainable` and
`StandardEvaluable` interfaces. Trainers inside this project should be
interchangeable and independent of model architectures and tasks.
"""
import dataclasses
import os
from typing import Any, Optional
# Import libraries
from absl import logging
import gin
import orbit
import tensorflow as tf
from official.core import base_task
from official.core import base_trainer as trainer_lib
from official.core import config_definitions
from official.modeling.fast_training.progressive import policies
from official.modeling.fast_training.progressive import utils
# Short alias used in the trainer's type annotations.
ExperimentConfig = config_definitions.ExperimentConfig
@dataclasses.dataclass
class ProgressiveTrainerConfig(config_definitions.TrainerConfig):
  """Configuration for progressive trainer.

  Attributes:
    progressive: A task-specific config. Users can subclass ProgressiveConfig
      and define any task-specific settings in their subclass.
    export_checkpoint: A bool. Whether to export checkpoints in non-progressive
      manner (without the volatiles wrapper) such that your down-stream tasks
      can load checkpoints from a progressive trainer as if it is a regular
      checkpoint.
    export_checkpoint_interval: An int. The number of steps between exporting
      checkpoints. If None (by default), will use the same value as
      TrainerConfig.checkpoint_interval.
    export_max_to_keep: An int. The maximum number of exported checkpoints to
      keep. If None (by default), will use the same value as
      TrainerConfig.max_to_keep.
    export_only_final_stage_ckpt: A bool. Whether to just export checkpoints
      during the final progressive training stage. In other words, whether to
      not export small, partial models. In many cases, it is not meaningful to
      finetune a small, partial model in down-stream tasks.
  """
  progressive: Optional[policies.ProgressiveConfig] = None
  export_checkpoint: bool = True
  export_checkpoint_interval: Optional[int] = None
  export_max_to_keep: Optional[int] = None
  export_only_final_stage_ckpt: bool = True
@gin.configurable
class ProgressiveTrainer(trainer_lib.Trainer):
  """Implements the progressive trainer shared for TensorFlow models.

  The model, optimizer and datasets are not owned by the trainer; they are
  read from the progressive task on every access so that they follow the
  task's current stage.
  """

  def __init__(
      self,
      config: ExperimentConfig,
      prog_task: base_task.Task,  # also implemented ProgressivePolicy.
      ckpt_dir: str = '',
      train: bool = True,
      evaluate: bool = True,
      checkpoint_exporter: Any = None):
    """Initialize common trainer for TensorFlow models.

    Args:
      config: An `ExperimentConfig` instance specifying experiment config.
      prog_task: An instance both implemented policies.ProgressivePolicy and
        base_task.Task.
      ckpt_dir: Checkpoint directory.
      train: bool, whether or not this trainer will be used for training.
        default to True.
      evaluate: bool, whether or not this trainer will be used for evaluation.
        default to True.
      checkpoint_exporter: an object that has the `maybe_export_checkpoint`
        interface.
    """
    # Gets the current distribution strategy. If not inside any strategy scope,
    # it gets a single-replica no-op strategy.
    self._strategy = tf.distribute.get_strategy()
    self._config = config
    self._runtime_options = trainer_lib.get_runtime_options(config)
    self._task = prog_task

    # Directory for non-progressive checkpoint
    self._export_ckpt_dir = os.path.join(ckpt_dir, 'exported_ckpts')
    tf.io.gfile.makedirs(self._export_ckpt_dir)
    # Built lazily by `_maybe_export_non_progressive_checkpoint`.
    self._export_ckpt_manager = None

    # Receive other checkpoint export, e.g, best checkpoint exporter.
    # TODO(lehou): unify the checkpoint exporting logic, although the default
    # setting does not use checkpoint_exporter.
    self._checkpoint_exporter = checkpoint_exporter

    self._global_step = orbit.utils.create_global_step()

    # The hook moves the task to the right stage before variables are loaded.
    self._checkpoint = utils.CheckpointWithHooks(
        before_load_hook=self._update_pt_stage_from_ckpt,
        global_step=self.global_step,
        **self._task.cur_checkpoint_items)

    self._train_loss = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
    self._validation_loss = tf.keras.metrics.Mean(
        'validation_loss', dtype=tf.float32)
    self._train_metrics = self.task.build_metrics(
        training=True) + self.model.metrics
    self._validation_metrics = self.task.build_metrics(
        training=False) + self.model.metrics

    if train:
      orbit.StandardTrainer.__init__(
          self,
          None,  # Manage train_dataset by ourselves, not by StandardTrainer.
          options=orbit.StandardTrainerOptions(
              use_tf_while_loop=config.trainer.train_tf_while_loop,
              use_tf_function=config.trainer.train_tf_function))

    if evaluate:
      orbit.StandardEvaluator.__init__(
          self,
          None,  # Manage eval_dataset by ourselves, not by StandardEvaluator.
          options=orbit.StandardEvaluatorOptions(
              use_tf_function=config.trainer.eval_tf_function))

  @property
  def model(self):
    """The current stage's model, as held by the progressive task."""
    return self._task.cur_model

  @property
  def optimizer(self):
    """The current stage's optimizer, as held by the progressive task."""
    return self._task.cur_optimizer

  # override
  @property
  def train_dataset(self):
    """Overriding StandardTrainer.train_dataset."""
    return self._task.cur_train_dataset

  # override
  @train_dataset.setter
  def train_dataset(self, _):
    raise SyntaxError('Please do not set train_dataset. Progressive training '
                      'relies on progressive policy to manager train dataset.')

  # override
  @property
  def eval_dataset(self):
    """Overriding StandardEvaluator.eval_dataset."""
    return self._task.cur_eval_dataset

  # override
  @eval_dataset.setter
  def eval_dataset(self, _):
    raise SyntaxError('Please do not set eval_dataset. Progressive training '
                      'relies on progressive policy to manager eval dataset.')

  def train_loop_end(self):
    """See base class.

    In addition to collecting and resetting the training metrics, this may
    export a non-progressive checkpoint and moves the task to the next stage
    when the global step has crossed a stage boundary.

    Returns:
      A dict of metric results plus the current `learning_rate`.
    """
    logs = {}
    for metric in self.train_metrics + [self.train_loss]:
      logs[metric.name] = metric.result()
      metric.reset_states()
    if callable(self.optimizer.learning_rate):
      logs['learning_rate'] = self.optimizer.learning_rate(
          self.optimizer.iterations)
    else:
      logs['learning_rate'] = self.optimizer.learning_rate

    # Export before advancing, so the export uses the stage just trained.
    self._maybe_export_non_progressive_checkpoint(self._export_ckpt_dir)
    self._advance_stage_if_needed(pass_old_model=True)

    return logs

  def _advance_stage_if_needed(self, pass_old_model):
    """Moves the task to the stage of the current global step, if it changed.

    Shared by `train_loop_end` and `_update_pt_stage_from_ckpt`. When the stage
    changes, cached loop functions, the train iterator (if the dataset changed)
    and the export checkpoint manager are reset so they are rebuilt against
    the new stage's model and optimizer.

    Args:
      pass_old_model: Forwarded to the task's `update_pt_stage`.
    """
    if not self._task.is_stage_advancing(self.global_step.numpy()):
      return

    old_train_dataset = self.train_dataset

    # Update progressive properties
    self._task.update_pt_stage(
        self.global_step.numpy(), pass_old_model=pass_old_model)

    # Setting `self._train_loop_fn` and `self._eval_loop_fn` to None will
    # rebuild the train and eval functions with the updated model.
    self._train_loop_fn = None
    self._eval_loop_fn = None

    if self.train_dataset != old_train_dataset:
      # Setting `self._train_iter` to None will rebuild the dataset iterator.
      self._train_iter = None

    # Setting `self._export_ckpt_manager` to None will rebuild the checkpoint
    # for exporting.
    self._export_ckpt_manager = None

  def _update_pt_stage_from_ckpt(self, ckpt_file):
    """Update stage properties based on the global_step variable in a ckpt file.

    Before loading variables from a checkpoint file, we need to go to the
    correct stage and build corresponding model and optimizer, to make sure that
    we restore variables of the right model and optimizer.

    Args:
      ckpt_file: Checkpoint file that will be restored/read from.
    """
    if not ckpt_file:
      return
    # Read only the global step; the rest of the checkpoint is loaded later.
    ckpt = tf.train.Checkpoint(global_step=self.global_step)
    ckpt.read(ckpt_file).expect_partial().assert_existing_objects_matched()

    # The current model is not passed on as `old_model` here.
    self._advance_stage_if_needed(pass_old_model=False)

  def _maybe_export_non_progressive_checkpoint(self, export_ckpt_dir):
    """Export checkpoints in non-progressive format.

    This basically removes the wrapping of self._task.cur_checkpoint_items
    -- just save the model, optimizer, etc., directly.
    The purpose is to let your down-stream tasks to use these checkpoints.

    Nothing is exported when `export_checkpoint` is off, or when
    `export_only_final_stage_ckpt` is on and the task is not in its last stage.

    Args:
      export_ckpt_dir: A str. folder of exported checkpoints.
    """
    if not self.config.trainer.export_checkpoint:
      logging.info('Not exporting checkpoints.')
      return
    if not self._task.is_last_stage and (
        self.config.trainer.export_only_final_stage_ckpt):
      logging.info('Not exporting checkpoints until the last stage.')
      return

    if self._export_ckpt_manager is None:
      # Create a checkpoint object just now, to make sure we use
      # progressive_policy.cur_model and progressive_policy.cur_optimizer of the
      # current stage.
      if hasattr(self.model, 'checkpoint_items'):
        checkpoint_items = self.model.checkpoint_items
      else:
        checkpoint_items = {}
      checkpoint = tf.train.Checkpoint(
          global_step=self.global_step,
          model=self.model,
          optimizer=self.optimizer,
          **checkpoint_items)

      # Export-specific settings fall back to the regular trainer settings.
      max_to_keep = self.config.trainer.export_max_to_keep or (
          self.config.trainer.max_to_keep)
      checkpoint_interval = self.config.trainer.export_checkpoint_interval or (
          self.config.trainer.checkpoint_interval)
      self._export_ckpt_manager = tf.train.CheckpointManager(
          checkpoint,
          directory=export_ckpt_dir,
          checkpoint_name='ckpt',
          step_counter=self.global_step,
          max_to_keep=max_to_keep,
          checkpoint_interval=checkpoint_interval,
      )

    # Make sure we export the last checkpoint.
    last_checkpoint = (
        self.global_step.numpy() == self._config.trainer.train_steps)
    checkpoint_path = self._export_ckpt_manager.save(
        checkpoint_number=self.global_step.numpy(),
        check_interval=not last_checkpoint)
    if checkpoint_path:
      logging.info('Checkpoints exported: %s.', checkpoint_path)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the progressive trainer."""
# pylint: disable=g-direct-tensorflow-import
import os
from absl.testing import parameterized
import orbit
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.core import config_definitions as cfg
from official.modeling import optimization
from official.modeling.fast_training.progressive import policies
from official.modeling.fast_training.progressive import trainer as trainer_lib
from official.nlp.configs import bert
from official.utils.testing import mock_task
def all_strategy_combinations():
  """Returns the test combinations over the three strategies under test."""
  strategies = [
      strategy_combinations.default_strategy,
      strategy_combinations.cloud_tpu_strategy,
      strategy_combinations.one_device_strategy_gpu,
  ]
  return combinations.combine(distribution=strategies)
def get_exp_config():
  """Builds the experiment config shared by the trainer tests.

  Returns:
    A `cfg.ExperimentConfig` whose task uses a BERT pretrainer model config
    and whose trainer is a `ProgressiveTrainerConfig` with checkpoint export
    enabled at an interval of 1, not restricted to the final stage.
  """
  task_config = cfg.TaskConfig(model=bert.PretrainerConfig())
  trainer_config = trainer_lib.ProgressiveTrainerConfig(
      export_checkpoint=True,
      export_checkpoint_interval=1,
      export_only_final_stage_ckpt=False)
  return cfg.ExperimentConfig(task=task_config, trainer=trainer_config)
class TestPolicy(policies.ProgressivePolicy, mock_task.MockTask):
  """Two-stage progressive policy built on the mock task; for tests only."""

  def __init__(self, strategy, task_config, change_train_dataset=True):
    """Initializes the policy.

    Args:
      strategy: distribution strategy used to distribute the datasets; when
        falsy, plain (non-distributed) datasets are returned.
      task_config: config handed to `MockTask` as `params`.
      change_train_dataset: if False, `get_train_dataset` keeps returning the
        first dataset it built instead of building a new one per call.
    """
    self._strategy = strategy
    self._change_train_dataset = change_train_dataset
    self._my_train_dataset = None
    mock_task.MockTask.__init__(self, params=task_config, logging_dir=None)
    policies.ProgressivePolicy.__init__(self)

  def num_stages(self) -> int:
    """Returns the number of stages (always 2)."""
    return 2

  def num_steps(self, stage_id: int) -> int:
    """Returns 2 steps for stage 0 and 4 steps for any other stage."""
    if stage_id == 0:
      return 2
    return 4

  def get_model(self,
                stage_id: int,
                old_model: tf.keras.Model) -> tf.keras.Model:
    """Returns a freshly built model; both arguments are ignored."""
    del stage_id, old_model
    return self.build_model()

  def get_optimizer(self, stage_id: int) -> tf.keras.optimizers.Optimizer:
    """Returns an SGD optimizer for stage 0 and AdamW otherwise."""
    if stage_id == 0:
      optimizer_type = 'sgd'
    else:
      optimizer_type = 'adamw'
    optimizer_config = cfg.OptimizationConfig({
        'optimizer': {'type': optimizer_type},
        'learning_rate': {'type': 'constant'}})
    factory = optimization.OptimizerFactory(optimizer_config)
    learning_rate = factory.build_learning_rate()
    return factory.build_optimizer(learning_rate)

  def get_train_dataset(self, stage_id: int) -> tf.data.Dataset:
    """Returns the training dataset for `stage_id`.

    When the policy was created with `change_train_dataset=False` and a
    dataset has already been built, that cached dataset is returned as is.
    """
    keep_cached = not self._change_train_dataset
    if keep_cached and self._my_train_dataset:
      return self._my_train_dataset
    if not self._strategy:
      self._my_train_dataset = self._build_inputs(stage_id)
    else:
      self._my_train_dataset = orbit.utils.make_distributed_dataset(
          self._strategy, self._build_inputs, stage_id)
    return self._my_train_dataset

  def get_eval_dataset(self, stage_id: int) -> tf.data.Dataset:
    """Returns a newly built evaluation dataset for `stage_id`."""
    if not self._strategy:
      return self._build_inputs(stage_id)
    return orbit.utils.make_distributed_dataset(
        self._strategy, self._build_inputs, stage_id)

  def _build_inputs(self, stage_id):
    """Builds an endless dataset of all-zero (features, label) batches.

    The batch size is 2 for stage 0 and 1 for every other stage, which lets
    tests tell from a batch which stage produced it.
    """
    batch_size = 2 if stage_id == 0 else 1

    def dummy_data(_):
      features = tf.zeros(shape=(batch_size, 2), dtype=tf.float32)
      labels = tf.zeros(shape=(batch_size, 1), dtype=tf.float32)
      return features, labels

    endless = tf.data.Dataset.range(1).repeat()
    return endless.map(
        dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
class TrainerTest(tf.test.TestCase, parameterized.TestCase):
  """Checks checkpointing and train-dataset handling of ProgressiveTrainer."""

  def setUp(self):
    super().setUp()
    self._config = get_exp_config()

  def create_test_trainer(self, distribution, model_dir, change_train_dataset):
    """Builds a ProgressiveTrainer driven by a `TestPolicy`.

    Args:
      distribution: strategy handed to `TestPolicy`.
      model_dir: directory passed to the trainer as `ckpt_dir`.
      change_train_dataset: forwarded to `TestPolicy`; when False the policy
        keeps returning the first training dataset it built.

    Returns:
      The constructed `trainer_lib.ProgressiveTrainer`.
    """
    policy = TestPolicy(distribution, self._config.task, change_train_dataset)
    return trainer_lib.ProgressiveTrainer(
        self._config, prog_task=policy, ckpt_dir=model_dir)

  def _first_train_batch(self, trainer, distribution):
    """Returns the features of the trainer's next training batch.

    With more than one replica in sync, the first per-replica value is
    returned.
    """
    train_iter = tf.nest.map_structure(iter, trainer.train_dataset)
    # Use the built-in `next()`: it only needs the iterator protocol, whereas
    # `.next()` depends on the iterator keeping a Python-2 style alias.
    train_data = next(train_iter)[0]
    if distribution.num_replicas_in_sync > 1:
      train_data = train_data.values[0]
    return train_data

  @combinations.generate(all_strategy_combinations())
  def test_checkpointing(self, distribution):
    """Restoring a checkpoint saved in the last stage restores the stage."""
    model_dir = self.get_temp_dir()
    ckpt_file = os.path.join(model_dir, 'ckpt')
    with distribution.scope():
      trainer = self.create_test_trainer(distribution, model_dir, True)
      self.assertFalse(trainer._task.is_last_stage)
      trainer.train(tf.convert_to_tensor(4, dtype=tf.int32))
      self.assertTrue(trainer._task.is_last_stage)
      trainer.checkpoint.save(ckpt_file)

      # A brand-new trainer starts in the first stage until it restores.
      trainer = self.create_test_trainer(distribution, model_dir, True)
      self.assertFalse(trainer._task.is_last_stage)
      trainer.checkpoint.restore(ckpt_file + '-1')
      self.assertTrue(trainer._task.is_last_stage)

  @combinations.generate(all_strategy_combinations())
  def test_train_dataset(self, distribution):
    """The train dataset follows the stage and cannot be assigned."""
    model_dir = self.get_temp_dir()
    with distribution.scope():
      trainer = self.create_test_trainer(distribution, model_dir, True)
      # Using dataset of stage == 0
      train_data = self._first_train_batch(trainer, distribution)
      self.assertEqual(train_data.shape[0], 2)

      trainer.train(tf.convert_to_tensor(4, dtype=tf.int32))
      # Using dataset of stage == 1
      train_data = self._first_train_batch(trainer, distribution)
      self.assertEqual(train_data.shape[0], 1)

      with self.assertRaises(SyntaxError):
        trainer.train_dataset = None

  @combinations.generate(all_strategy_combinations())
  def test_train_dataset_no_switch(self, distribution):
    """The train iterator is reset only when the policy changes datasets."""
    model_dir = self.get_temp_dir()
    with distribution.scope():
      trainer = self.create_test_trainer(distribution, model_dir, False)
      trainer.train(tf.convert_to_tensor(2, dtype=tf.int32))
      # _train_iter is not reset since the dataset is not changed.
      self.assertIsNotNone(trainer._train_iter)
    with distribution.scope():
      trainer = self.create_test_trainer(distribution, model_dir, True)
      trainer.train(tf.convert_to_tensor(2, dtype=tf.int32))
      # _train_iter is reset since the dataset changed.
      self.assertIsNone(trainer._train_iter)
class TrainerWithMaskedLMTaskTest(tf.test.TestCase, parameterized.TestCase):
  """Exercises train/evaluate and optimizer setup of ProgressiveTrainer."""

  def setUp(self):
    super().setUp()
    self._config = get_exp_config()

  def create_test_trainer(self, distribution):
    """Returns a ProgressiveTrainer checkpointing into a temp directory."""
    policy = TestPolicy(distribution, self._config.task)
    return trainer_lib.ProgressiveTrainer(
        self._config, prog_task=policy, ckpt_dir=self.get_temp_dir())

  @combinations.generate(all_strategy_combinations())
  def test_trainer_train(self, distribution):
    """Training reports the loss and the learning rate."""
    with distribution.scope():
      trainer = self.create_test_trainer(distribution)
      num_steps = tf.convert_to_tensor(5, dtype=tf.int32)
      logs = trainer.train(num_steps)
      for key in ('training_loss', 'learning_rate'):
        self.assertIn(key, logs)

  @combinations.generate(all_strategy_combinations())
  def test_trainer_validate(self, distribution):
    """Evaluation reports the loss and counts steps across replicas."""
    with distribution.scope():
      trainer = self.create_test_trainer(distribution)
      num_steps = tf.convert_to_tensor(5, dtype=tf.int32)
      logs = trainer.evaluate(num_steps)
      self.assertIn('validation_loss', logs)
      expected_count = 5. * distribution.num_replicas_in_sync
      self.assertEqual(logs['counter'], expected_count)

  @combinations.generate(
      combinations.combine(
          mixed_precision_dtype=['float32', 'bfloat16', 'float16'],
          loss_scale=[None, 'dynamic', 128, 256],
      ))
  def test_configure_optimizer(self, mixed_precision_dtype, loss_scale):
    """The trainer builds and trains for every precision/loss-scale pair."""
    task_config = cfg.TaskConfig(model=bert.PretrainerConfig())
    runtime_config = cfg.RuntimeConfig(
        mixed_precision_dtype=mixed_precision_dtype, loss_scale=loss_scale)
    trainer_config = trainer_lib.ProgressiveTrainerConfig(
        export_checkpoint=True,
        export_checkpoint_interval=1,
        export_only_final_stage_ckpt=False)
    config = cfg.ExperimentConfig(
        task=task_config, runtime=runtime_config, trainer=trainer_config)
    task = TestPolicy(None, config.task)
    trainer = trainer_lib.ProgressiveTrainer(config, task, self.get_temp_dir())

    # The optimizer type is only asserted when the dtype is not float16, or
    # when it is float16 without a loss scale.
    if mixed_precision_dtype != 'float16' or loss_scale is None:
      self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)

    metrics = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
    self.assertIn('training_loss', metrics)
# Run the test cases of this module when it is executed as a script.
if __name__ == '__main__':
  tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Util classes and functions."""
from absl import logging
import tensorflow as tf
# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.training.tracking import tracking
class VolatileTrackable(tracking.AutoTrackable):
  """Holds Trackables under names whose bound instance may be swapped later."""

  def __init__(self, **kwargs):
    """Sets every keyword argument as an attribute named after its key."""
    for name, trackable in kwargs.items():
      setattr(self, name, trackable)

  def reassign_trackable(self, **kwargs):
    """Replaces existing attributes with the given new objects.

    Each name is deleted before being set again, so a name that is not
    currently an attribute makes `delattr` fail.

    Args:
      **kwargs: attribute name to replacement object.
    """
    for name, replacement in kwargs.items():
      delattr(self, name)  # untrack the old object
      setattr(self, name, replacement)  # track the new object
class CheckpointWithHooks(tf.train.Checkpoint):
  """Same as tf.train.Checkpoint but supports hooks.

  In progressive training, use this class instead of tf.train.Checkpoint.

  Since the network architecture changes during progressive training, we need to
  prepare something (like switch to the correct architecture) before loading the
  checkpoint. This class supports a hook that will be executed before checkpoint
  loading.
  """

  def __init__(self, before_load_hook, **kwargs):
    """Initializes the checkpoint.

    Args:
      before_load_hook: callable invoked with the `save_path` right before a
        checkpoint is read.
      **kwargs: forwarded unchanged to `tf.train.Checkpoint`.
    """
    self._before_load_hook = before_load_hook
    super().__init__(**kwargs)

  # override
  def read(self, save_path, options=None):
    """Runs the before-load hook, then reads the checkpoint.

    Args:
      save_path: path of the checkpoint to read; also passed to the hook.
      options: forwarded to `tf.train.Checkpoint.read`.

    Returns:
      The load status object produced by `tf.train.Checkpoint.read`, so that
      callers can still make assertions on the restoration.
    """
    self._before_load_hook(save_path)
    logging.info('Ran before_load_hook.')
    # Propagate the status instead of dropping it; otherwise callers of
    # `read` receive None.
    return super().read(save_path=save_path, options=options)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Some gradient util functions to help users writing custom training loop."""
from absl import logging
import tensorflow as tf
def _filter_grads(grads_and_vars):
  """Drops the (gradient, variable) pairs whose gradient is None.

  Args:
    grads_and_vars: iterable of (gradient, variable) pairs.

  Returns:
    A tuple with the pairs that have a gradient, in their original order. An
    empty input yields an empty tuple.

  Raises:
    ValueError: if the input is non-empty but no pair has a gradient.
  """
  pairs = tuple(grads_and_vars)
  if not pairs:
    return pairs

  kept = tuple((grad, var) for grad, var in pairs if grad is not None)
  if not kept:
    raise ValueError("No gradients provided for any variable: %s." %
                     ([var.name for _, var in pairs],))

  skipped_vars = [var for grad, var in pairs if grad is None]
  if skipped_vars:
    logging.warning(
        ("Gradients do not exist for variables %s when minimizing the loss."),
        ([var.name for var in skipped_vars]))
  return kept
def _filter_and_allreduce_gradients(grads_and_vars,
                                    allreduce_precision="float32",
                                    bytes_per_pack=0):
  """Removes None gradients, then sums the rest across replicas.

  This helper is for users who intend to allreduce gradients explicitly and
  to customize gradient operations before and after the allreduce. The
  allreduced gradients are then passed to optimizer.apply_gradients(
  experimental_aggregate_gradients=False).

  Args:
    grads_and_vars: gradients and variables pairs.
    allreduce_precision: "float16" casts the gradients to float16 for the
      allreduce and casts the result to float32 afterwards; with any other
      value the gradients are allreduced without casting.
    bytes_per_pack: A non-negative integer. Breaks collective operations into
      packs of certain size. If it's zero, all gradients are in one pack.

  Returns:
    A pair `(allreduced_grads, variables)` covering only the variables whose
    gradient was not None.
  """
  reduce_in_fp16 = allreduce_precision == "float16"
  grads, variables = zip(*_filter_grads(grads_and_vars))
  if reduce_in_fp16:
    grads = [tf.cast(grad, "float16") for grad in grads]

  options = tf.distribute.experimental.CommunicationOptions(
      bytes_per_pack=bytes_per_pack)
  strategy_extended = tf.distribute.get_strategy().extended
  reduced_grads = strategy_extended._replica_ctx_all_reduce(  # pylint: disable=protected-access
      tf.distribute.ReduceOp.SUM, grads, options)

  if reduce_in_fp16:
    reduced_grads = [tf.cast(grad, "float32") for grad in reduced_grads]
  return reduced_grads, variables
def _run_callbacks(callbacks, grads_and_vars):
  """Threads `grads_and_vars` through every callback, in list order.

  Each callback receives the value returned by the previous one; the value
  returned by the last callback is the result.
  """
  current = grads_and_vars
  for transform in callbacks:
    current = transform(current)
  return current
def minimize_using_explicit_allreduce(tape,
                                      optimizer,
                                      loss,
                                      trainable_variables,
                                      pre_allreduce_callbacks=None,
                                      post_allreduce_callbacks=None,
                                      allreduce_bytes_per_pack=0):
  """Minimizes loss for one step by updating `trainable_variables`.

  The gradients are allreduced explicitly here instead of relying on the
  implicit allreduce in optimizer.apply_gradients(). When `optimizer` is a
  `tf.keras.mixed_precision.LossScaleOptimizer`, the scaled gradients are
  aggregated in FP16 and unscaled afterwards; otherwise the gradients are
  aggregated in FP32.

  Args:
    tape: An instance of `tf.GradientTape`.
    optimizer: An instance of `tf.keras.optimizers.Optimizer`.
    loss: the loss tensor.
    trainable_variables: A list of model Variables.
    pre_allreduce_callbacks: A list of callback functions that take gradients
      and model variables pairs as input, manipulate them, and return new
      gradients and model variables pairs. The callbacks are invoked in list
      order, before the gradients are allreduced. With a loss scale optimizer
      they are applied to the scaled gradients. Default is no callbacks.
    post_allreduce_callbacks: A list of callback functions that take gradients
      and model variables pairs as input, manipulate them, and return new
      gradients and model variables pairs. The callbacks are invoked in list
      order, right before the gradients are applied to the variables. Default
      is no callbacks.
    allreduce_bytes_per_pack: A non-negative integer. Breaks collective
      operations into packs of certain size. If it's zero, all gradients are
      in one pack.
  """
  uses_loss_scaling = isinstance(
      optimizer, tf.keras.mixed_precision.LossScaleOptimizer)

  if uses_loss_scaling:
    # FP16 GPU code path: differentiate the scaled loss.
    with tape:
      target_loss = optimizer.get_scaled_loss(loss)
    precision = "float16"
  else:
    # TPU or FP32 GPU code path.
    target_loss = loss
    precision = "float32"

  raw_grads = tape.gradient(target_loss, trainable_variables)
  grads_and_vars = zip(raw_grads, trainable_variables)
  if pre_allreduce_callbacks:
    grads_and_vars = _run_callbacks(pre_allreduce_callbacks, grads_and_vars)

  reduced_grads, kept_vars = _filter_and_allreduce_gradients(
      grads_and_vars,
      allreduce_precision=precision,
      bytes_per_pack=allreduce_bytes_per_pack)
  if uses_loss_scaling:
    reduced_grads = optimizer.get_unscaled_gradients(reduced_grads)

  grads_and_vars = zip(reduced_grads, kept_vars)
  if post_allreduce_callbacks:
    grads_and_vars = _run_callbacks(post_allreduce_callbacks, grads_and_vars)
  optimizer.apply_gradients(
      grads_and_vars, experimental_aggregate_gradients=False)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Hyperparams package definition."""
# pylint: disable=g-multiple-import
from official.modeling.hyperparams.base_config import *
from official.modeling.hyperparams.oneof import *
from official.modeling.hyperparams.params_dict import *
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base configurations to standardize experiments."""
import copy
import dataclasses
import functools
import inspect
from typing import Any, List, Mapping, Optional, Type
from absl import logging
import tensorflow as tf
import yaml
from official.modeling.hyperparams import params_dict
# Config classes that already have a builder bound to them.
_BOUND = set()


def bind(config_cls):
  """Returns a decorator that attaches a builder to `config_cls`.

  The decorated builder is stored on `config_cls._BUILDER`. A class is stored
  as is; a plain function is stored behind a wrapper that discards its first
  positional argument before calling the function. The decorator returns the
  builder unchanged.

  Args:
    config_cls: the config class to bind a builder to.

  Returns:
    A decorator taking the builder (a class or a function).

  Raises:
    ValueError: if `config_cls` is not a class; raised by the decorator if
      `config_cls` is already bound or the builder is neither a class nor a
      function.
  """
  if not inspect.isclass(config_cls):
    raise ValueError('The bind decorator is supposed to apply on the class '
                     f'attribute. Received {config_cls}, not a class.')

  def decorator(builder):
    if config_cls in _BOUND:
      raise ValueError(
          'Inside a program, we should not bind the config with a class twice.')

    builder_is_class = inspect.isclass(builder)
    if not builder_is_class and not inspect.isfunction(builder):
      raise ValueError(f'The `BUILDER` type is not supported: {builder}')

    if builder_is_class:
      bound = builder
    else:

      def _wrapper(self, *args, **kwargs):  # pylint: disable=unused-argument
        return builder(*args, **kwargs)

      bound = _wrapper

    config_cls._BUILDER = bound  # pylint: disable=protected-access
    _BOUND.add(config_cls)
    return builder

  return decorator
@dataclasses.dataclass
class Config(params_dict.ParamsDict):
  """The base configuration class that supports YAML/JSON based overrides.

  Because of YAML/JSON serialization limitations, some semantics of dataclass
  are not supported:
  * It recursively enforces an allowlist of basic types and container types, so
    it avoids surprises with copy and reuse caused by unanticipated types.
  * Warning: it converts Dict to `Config` even within sequences,
    e.g. for config = Config({'key': [{'a': 42}]}),
         type(config.key[0]) is Config rather than dict.
    If you define/annotate some field as Dict, the field will convert to a
    `Config` instance and lose the dictionary type.
  """

  # The class or method to bind with the params class.
  _BUILDER = None
  # It's safe to add bytes and other immutable types here.
  IMMUTABLE_TYPES = (str, int, float, bool, type(None))
  # It's safe to add set, frozenset and other collections here.
  SEQUENCE_TYPES = (list, tuple)

  default_params: dataclasses.InitVar[Optional[Mapping[str, Any]]] = None
  restrictions: dataclasses.InitVar[Optional[List[str]]] = None

  def __post_init__(self, default_params, restrictions):
    """Forwards the init-only arguments to the `ParamsDict` constructor."""
    super().__init__(
        default_params=default_params,
        restrictions=restrictions)

  @property
  def BUILDER(self):
    """The builder bound to this config class (None when nothing is bound)."""
    return self._BUILDER

  @classmethod
  def _isvalidsequence(cls, v):
    """Check if the input values are valid sequences.

    Args:
      v: Input sequence.

    Returns:
      True if the sequence is valid. Valid sequence includes the sequence
      type in cls.SEQUENCE_TYPES and element type is in cls.IMMUTABLE_TYPES or
      is dict or ParamsDict. The elements must all fall in the same one of
      these three groups; an empty sequence is valid.
    """
    if not isinstance(v, cls.SEQUENCE_TYPES):
      return False
    return (all(isinstance(e, cls.IMMUTABLE_TYPES) for e in v) or
            all(isinstance(e, dict) for e in v) or
            all(isinstance(e, params_dict.ParamsDict) for e in v))

  @classmethod
  def _import_config(cls, v, subconfig_type):
    """Returns v with dicts converted to Configs, recursively.

    Args:
      v: the value to import.
      subconfig_type: the `ParamsDict` subclass used to wrap dict values.

    Returns:
      `v` itself for immutable values, a sequence of the same type with every
      element imported, a deep copy for `ParamsDict` values, or
      `subconfig_type(v)` for dicts.

    Raises:
      TypeError: if `subconfig_type` is not a `ParamsDict` subclass, if `v` is
        an invalid sequence, or if `v` has an unsupported type.
    """
    if not issubclass(subconfig_type, params_dict.ParamsDict):
      raise TypeError(
          'Subconfig_type should be subclass of ParamsDict, found {!r}'.format(
              subconfig_type))
    if isinstance(v, cls.IMMUTABLE_TYPES):
      return v
    elif isinstance(v, cls.SEQUENCE_TYPES):
      # Only support one layer of sequence.
      if not cls._isvalidsequence(v):
        raise TypeError(
            'Invalid sequence: only supports single level {!r} of {!r} or '
            'dict or ParamsDict found: {!r}'.format(cls.SEQUENCE_TYPES,
                                                    cls.IMMUTABLE_TYPES, v))
      import_fn = functools.partial(
          cls._import_config, subconfig_type=subconfig_type)
      return type(v)(map(import_fn, v))
    elif isinstance(v, params_dict.ParamsDict):
      # Deepcopy here is a temporary solution for preserving type in nested
      # Config object.
      return copy.deepcopy(v)
    elif isinstance(v, dict):
      return subconfig_type(v)
    else:
      raise TypeError('Unknown type: {!r}'.format(type(v)))

  @classmethod
  def _export_config(cls, v):
    """Returns v with Configs converted to dicts, recursively.

    Raises:
      TypeError: if `v` is a plain dict or has an unsupported type.
    """
    if isinstance(v, cls.IMMUTABLE_TYPES):
      return v
    elif isinstance(v, cls.SEQUENCE_TYPES):
      return type(v)(map(cls._export_config, v))
    elif isinstance(v, params_dict.ParamsDict):
      return v.as_dict()
    elif isinstance(v, dict):
      raise TypeError('dict value not supported in converting.')
    else:
      raise TypeError('Unknown type: {!r}'.format(type(v)))

  @classmethod
  def _get_subconfig_type(cls, k) -> Type[params_dict.ParamsDict]:
    """Get element type by the field name.

    Args:
      k: the key/name of the field.

    Returns:
      Config as default. If a type annotation is found for `k`,
      1) returns the type of the annotation if it is subtype of Config;
      2) returns the element type if the annotation of `k` is List[SubType]
         or Tuple[SubType] and SubType is a subtype of ParamsDict.
    """
    subconfig_type = Config
    if k not in cls.__annotations__:
      return subconfig_type

    type_annotation = cls.__annotations__[k]  # pytype: disable=invalid-annotation
    # Directly Config subtype.
    if (isinstance(type_annotation, type) and
        issubclass(type_annotation, Config)):
      return type_annotation

    # Check if the field is a sequence of subtypes.
    field_type = getattr(type_annotation, '__origin__', type(None))
    if (isinstance(field_type, type) and
        issubclass(field_type, cls.SEQUENCE_TYPES)):
      # `__args__` may be missing or empty; fall back to NoneType then.
      type_args = getattr(type_annotation, '__args__', None) or (type(None),)
      element_type = type_args[0]
      # The element may be a typing construct rather than a class, and
      # issubclass() raises TypeError on non-class arguments.
      if (isinstance(element_type, type) and
          issubclass(element_type, params_dict.ParamsDict)):
        subconfig_type = element_type
    return subconfig_type

  def _set(self, k, v):
    """Overrides same method in ParamsDict.

    Also called by ParamsDict methods.

    Args:
      k: key to set.
      v: value.

    Raises:
      TypeError: from `_import_config` when `v` is not a supported value.
    """
    subconfig_type = self._get_subconfig_type(k)

    def is_null(k):
      # A key counts as null when it is missing or holds a falsy value.
      return k not in self.__dict__ or not self.__dict__[k]

    if isinstance(v, dict):
      if is_null(k):
        # If the key does not exist or the value is None, a new Config-family
        # object should be created for the key.
        self.__dict__[k] = subconfig_type(v)
      else:
        self.__dict__[k].override(v)
    elif not is_null(k) and isinstance(v, self.SEQUENCE_TYPES) and all(
        not isinstance(e, self.IMMUTABLE_TYPES) for e in v):
      if len(self.__dict__[k]) == len(v):
        # Same length: override the existing elements pairwise, in place.
        for existing, new_value in zip(self.__dict__[k], v):
          existing.override(new_value)
      elif not all(isinstance(e, self.IMMUTABLE_TYPES) for e in v):
        logging.warning(
            "The list/tuple don't match the value dictionaries provided. Thus, "
            'the list/tuple is determined by the type annotation and '
            'values provided. This is error-prone.')
        self.__dict__[k] = self._import_config(v, subconfig_type)
      else:
        self.__dict__[k] = self._import_config(v, subconfig_type)
    else:
      self.__dict__[k] = self._import_config(v, subconfig_type)

  def __setattr__(self, k, v):
    """Sets attribute `k` through `_set`.

    Raises:
      AttributeError: if `k` is `BUILDER` or `_BUILDER`.
      ValueError: if the config is locked and `k` is not a reserved attribute.
    """
    if k == 'BUILDER' or k == '_BUILDER':
      raise AttributeError('`BUILDER` is a property and `_BUILDER` is the '
                           'reserved class attribute. We should only assign '
                           '`_BUILDER` at the class level.')

    if k not in self.RESERVED_ATTR:
      if getattr(self, '_locked', False):
        raise ValueError('The Config has been locked. No change is allowed.')
    self._set(k, v)

  def _override(self, override_dict, is_strict=True):
    """Overrides same method in ParamsDict.

    Also called by ParamsDict methods.

    Args:
      override_dict: dictionary to write to .
      is_strict: If True, not allows to add new keys.

    Raises:
      KeyError: overriding reserved keys or keys not exist (is_strict=True).
    """
    for k, v in sorted(override_dict.items()):
      if k in self.RESERVED_ATTR:
        raise KeyError('The key {!r} is internally reserved. '
                       'Can not be overridden.'.format(k))
      if k not in self.__dict__:
        if is_strict:
          raise KeyError('The key {!r} does not exist in {!r}. '
                         'To extend the existing keys, use '
                         '`override` with `is_strict` = False.'.format(
                             k, type(self)))
        else:
          self._set(k, v)
      else:
        # Recurse into an existing truthy value so that it keeps its type.
        if isinstance(v, dict) and self.__dict__[k]:
          self.__dict__[k]._override(v, is_strict)  # pylint: disable=protected-access
        elif isinstance(v, params_dict.ParamsDict) and self.__dict__[k]:
          self.__dict__[k]._override(v.as_dict(), is_strict)  # pylint: disable=protected-access
        else:
          self._set(k, v)

  def as_dict(self):
    """Returns a dict representation of params_dict.ParamsDict.

    For the nested params_dict.ParamsDict, a nested dict will be returned.
    Reserved attributes are left out.
    """
    return {
        k: self._export_config(v)
        for k, v in self.__dict__.items()
        if k not in self.RESERVED_ATTR
    }

  def replace(self, **kwargs):
    """Overrides/returns a unlocked copy with the current config unchanged."""
    # pylint: disable=protected-access
    params = copy.deepcopy(self)
    params._locked = False
    params._override(kwargs, is_strict=True)
    # pylint: enable=protected-access
    return params

  @classmethod
  def from_yaml(cls, file_path: str):
    """Builds a default config and overrides it with the file's content."""
    # Note: This only works if the Config has all default values.
    with tf.io.gfile.GFile(file_path, 'r') as f:
      # NOTE(review): yaml.FullLoader is more permissive than yaml.SafeLoader;
      # confirm that only trusted files are loaded here.
      loaded = yaml.load(f, Loader=yaml.FullLoader)
    config = cls()
    config.override(loaded)
    return config

  @classmethod
  def from_json(cls, file_path: str):
    """Wrapper for `from_yaml`."""
    return cls.from_yaml(file_path)

  @classmethod
  def from_args(cls, *args, **kwargs):
    """Builds a config from the given list of arguments.

    Positional arguments are matched, in order, with the names found in
    `cls.__annotations__`; keyword arguments are applied afterwards and win
    over positional ones with the same name.
    """
    attributes = list(cls.__annotations__.keys())
    default_params = dict(zip(attributes, args))
    default_params.update(kwargs)
    return cls(default_params=default_params)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pprint
from typing import List, Tuple
from absl.testing import parameterized
import dataclasses
import tensorflow as tf
from official.modeling.hyperparams import base_config
@dataclasses.dataclass
class DumpConfig1(base_config.Config):
  """Leaf test config holding one int field and one str field."""
  a: int = 1
  b: str = 'text'
@dataclasses.dataclass
class DumpConfig2(base_config.Config):
  """Test config with two primitive fields and a nested `DumpConfig1`."""
  c: int = 2
  d: str = 'text'
  # Built through a factory: a config instance used directly as the default
  # is an unhashable default value, which dataclasses reject on Python 3.11+.
  e: DumpConfig1 = dataclasses.field(default_factory=DumpConfig1)
@dataclasses.dataclass
class DumpConfig3(DumpConfig2):
  """Test config adding a list field and a tuple field of `DumpConfig1`."""
  f: int = 2
  # NOTE(review): `g` is annotated twice in this class; the Tuple declaration
  # further down is the one that takes effect.
  g: str = 'text'
  h: List[DumpConfig1] = dataclasses.field(
      default_factory=lambda: [DumpConfig1(), DumpConfig1()])
  g: Tuple[DumpConfig1, ...] = (DumpConfig1(),)
@dataclasses.dataclass
class DumpConfig4(DumpConfig2):
  """`DumpConfig2` extended with one extra int field."""
  x: int = 3
@dataclasses.dataclass
class DummyConfig5(base_config.Config):
  """Test config with a tuple of configs and a tuple of strings."""
  # The default mixes a DumpConfig2 with an instance of its subclass.
  y: Tuple[DumpConfig2, ...] = (DumpConfig2(), DumpConfig4())
  z: Tuple[str] = ('a',)
class BaseConfigTest(parameterized.TestCase, tf.test.TestCase):
def assertHasSameTypes(self, c, d, msg=''):
  """Recursively checks that `c` mirrors the structure of reference `d`.

  Args:
    c: the Config object (or primitive/sequence inside one) to be checked.
    d: the reference: a dict, a sequence or a primitive, never a Config.
    msg: The error message to show when type mismatched.

  Raises:
    TypeError: if `d` is of a type this helper does not know.
  """
  self.assertNotIsInstance(d, base_config.Config)
  config_cls = base_config.Config

  if isinstance(d, config_cls.IMMUTABLE_TYPES):
    self.assertEqual(pprint.pformat(c), pprint.pformat(d), msg=msg)
    return

  if isinstance(d, config_cls.SEQUENCE_TYPES):
    self.assertEqual(type(c), type(d), msg=msg)
    for index, item in enumerate(d):
      self.assertHasSameTypes(
          c[index], item, msg='{}[{!r}]'.format(msg, index))
    return

  if isinstance(d, dict):
    self.assertIsInstance(c, config_cls, msg=msg)
    for key, item in sorted(d.items()):
      self.assertHasSameTypes(
          getattr(c, key), item, msg='{}[{!r}]'.format(msg, key))
    return

  raise TypeError('Unknown type: %r' % type(d))
def assertImportExport(self, v):
  """Checks that `v` survives a round trip through a Config under 'key'."""
  config = base_config.Config({'key': v})
  exported = config.as_dict()['key']
  self.assertEqual(pprint.pformat(exported), pprint.pformat(v))
  mismatch_msg = '=%s v' % pprint.pformat(v)
  self.assertHasSameTypes(config.key, v, msg=mismatch_msg)
def test_invalid_keys(self):
  """Reading an attribute that was never set raises AttributeError."""
  empty_config = base_config.Config()
  with self.assertRaises(AttributeError):
    _ = empty_config.a
def test_cls(self):
  """BUILDER cannot be assigned on instances and binds only once."""
  params = base_config.Config()
  reserved_msg = '`BUILDER` is a property and `_BUILDER` is the reserved'
  for attr_name in ('BUILDER', '_BUILDER'):
    with self.assertRaisesRegex(AttributeError, reserved_msg):
      setattr(params, attr_name, DumpConfig2)

  # Binding a class exposes it through the BUILDER property.
  base_config.bind(DumpConfig1)(DumpConfig2)
  params = DumpConfig1()
  self.assertEqual(params.BUILDER, DumpConfig2)
  with self.assertRaisesRegex(ValueError,
                              'Inside a program, we should not bind'):
    base_config.bind(DumpConfig1)(DumpConfig2)

  def _test():
    return 'test'

  # Binding a function makes BUILDER callable without arguments.
  base_config.bind(DumpConfig2)(_test)
  params = DumpConfig2()
  self.assertEqual(params.BUILDER(), 'test')
def test_nested_config_types(self):
  """Nested fields keep their annotated config type across overrides."""
  config = DumpConfig3()
  for nested in (config.e, config.h[0], config.h[1], config.g[0]):
    self.assertIsInstance(nested, DumpConfig1)

  # Directly nested config.
  config.override({'e': {'a': 2, 'b': 'new text'}})
  self.assertIsInstance(config.e, DumpConfig1)
  self.assertEqual(config.e.a, 2)
  self.assertEqual(config.e.b, 'new text')

  # List field, overridden with a list of a different length.
  config.override({'h': [{'a': 3, 'b': 'new text 2'}]})
  first_h = config.h[0]
  self.assertIsInstance(first_h, DumpConfig1)
  self.assertLen(config.h, 1)
  self.assertEqual(first_h.a, 3)
  self.assertEqual(first_h.b, 'new text 2')

  # Tuple field, overridden with a list of the same length.
  config.override({'g': [{'a': 4, 'b': 'new text 3'}]})
  first_g = config.g[0]
  self.assertIsInstance(first_g, DumpConfig1)
  self.assertLen(config.g, 1)
  self.assertEqual(first_g.a, 4)
  self.assertEqual(first_g.b, 'new text 3')
def test_replace(self):
  """`replace` returns an updated copy that keeps nested config types."""
  replaced = DumpConfig2().replace(e={'a': 2})
  self.assertEqual(replaced.e.a, 2)
  self.assertIsInstance(replaced.e, DumpConfig1)

  # The nested field holds an instance of another config type here.
  replaced = DumpConfig2(e=DumpConfig2()).replace(e={'c': 4})
  self.assertEqual(replaced.e.c, 4)
  self.assertIsInstance(replaced.e, DumpConfig2)

  # Sequence field.
  replaced = DumpConfig3().replace(g=[{'a': 4, 'b': 'new text 3'}])
  self.assertIsInstance(replaced.g[0], DumpConfig1)
  self.assertEqual(replaced.g[0].a, 4)
@parameterized.parameters(
    ('_locked', "The key '_locked' is internally reserved."),
    ('_restrictions', "The key '_restrictions' is internally reserved."),
    ('aa', "The key 'aa' does not exist."),
)
def test_key_error(self, key, msg):
  """Overriding a reserved or unknown key raises KeyError."""
  config = base_config.Config()
  overrides = {key: True}
  with self.assertRaisesRegex(KeyError, msg):
    config.override(overrides)
@parameterized.parameters(
    ('str data',),
    (123,),
    (1.23,),
    (None,),
    (['str', 1, 2.3, None],),
    (('str', 1, 2.3, None),),
)
def test_import_export_immutable_types(self, v):
  """Primitives and flat sequences of primitives are stored unchanged."""
  self.assertImportExport(v)
  config = base_config.Config({'key': v})
  self.assertEqual(pprint.pformat(v), pprint.pformat(config.key))
def test_override_is_strict_true(self):
  """Strict overrides update existing keys and reject unknown ones."""
  initial = {'a': 'aa', 'b': 2, 'c': {'c1': 'cc', 'c2': 20}}
  params = base_config.Config(initial)
  params.override({'a': 2, 'c': {'c1': 'ccc'}}, is_strict=True)
  self.assertEqual(params.a, 2)
  self.assertEqual(params.c.c1, 'ccc')

  # Unknown keys are rejected at the top level and inside nested configs.
  with self.assertRaises(KeyError):
    params.override({'d': 'ddd'}, is_strict=True)
  with self.assertRaises(KeyError):
    params.override({'c': {'c3': 30}}, is_strict=True)

  # ... and inside configs held by a sequence.
  config = base_config.Config({'key': [{'a': 42}]})
  with self.assertRaisesRegex(KeyError, "The key 'b' does not exist"):
    config.override({'key': [{'b': 43}]})
@parameterized.parameters(
    (lambda x: x, 'Unknown type'),
    (object(), 'Unknown type'),
    (set(), 'Unknown type'),
    (frozenset(), 'Unknown type'),
)
def test_import_unsupport_types(self, v, msg):
  """Values of unsupported types are rejected with a TypeError."""
  unsupported = {'key': v}
  with self.assertRaisesRegex(TypeError, msg):
    _ = base_config.Config(unsupported)
@parameterized.parameters(
    ({'a': [{'b': 2}, {'c': 3}]},),
    ({'c': [{'f': 1.1}, {'h': [1, 2]}]},),
    (({'a': 'aa', 'b': 2, 'c': {'c1': 10, 'c2': 20}},),),
)
def test_import_export_nested_structure(self, d):
  """Dicts nested in dicts and in sequences round-trip unchanged."""
  self.assertImportExport(d)
@parameterized.parameters(
    ([{'a': 42, 'b': 'hello', 'c': 1.2}],),
    (({'a': 42, 'b': 'hello', 'c': 1.2},),),
)
def test_import_export_nested_sequences(self, v):
  """A single-level list or tuple of dicts round-trips unchanged."""
  self.assertImportExport(v)
@parameterized.parameters(
    ([([{}],)],),
    ([['str', 1, 2.3, None]],),
    ((('str', 1, 2.3, None),),),
    ([('str', 1, 2.3, None)],),
    ([('str', 1, 2.3, None)],),
    ([[{'a': 42, 'b': 'hello', 'c': 1.2}]],),
    ([[[{'a': 42, 'b': 'hello', 'c': 1.2}]]],),
    ((({'a': 42, 'b': 'hello', 'c': 1.2},),),),
    (((({'a': 42, 'b': 'hello', 'c': 1.2},),),),),
    ([({'a': 42, 'b': 'hello', 'c': 1.2},)],),
    (([{'a': 42, 'b': 'hello', 'c': 1.2}],),),
)
def test_import_export_unsupport_sequence(self, v):
  """Sequences nested inside sequences are rejected with a TypeError."""
  expected_msg = 'Invalid sequence: only supports single level'
  with self.assertRaisesRegex(TypeError, expected_msg):
    _ = base_config.Config({'key': v})
def test_construct_subtype(self):
  """Placeholder test: the body performs no checks."""
  pass
def test_import_config(self):
  """Dicts inside a list (and dicts nested in them) are imported as Configs."""
  params = base_config.Config({'a': [{'b': 2}, {'c': {'d': 3}}]})
  self.assertLen(params.a, 2)

  first = params.a[0]
  self.assertEqual(first.b, 2)
  self.assertEqual(type(first), base_config.Config)
  self.assertEqual(pprint.pformat(first.b), '2')

  second = params.a[1]
  self.assertEqual(type(second), base_config.Config)
  self.assertEqual(type(second.c), base_config.Config)
  self.assertEqual(pprint.pformat(second.c.d), '3')
def test_override(self):
  """Non-strict override of a list of dicts keeps Config elements in a list."""
  params = base_config.Config({'a': [{'b': 2}, {'c': {'d': 3}}]})
  replacement = {'a': [{'b': 4}, {'c': {'d': 5}}]}
  params.override(replacement, is_strict=False)

  self.assertEqual(type(params.a), list)

  first = params.a[0]
  self.assertEqual(type(first), base_config.Config)
  self.assertEqual(pprint.pformat(first.b), '4')

  second = params.a[1]
  self.assertEqual(type(second), base_config.Config)
  self.assertEqual(type(second.c), base_config.Config)
  self.assertEqual(pprint.pformat(second.c.d), '5')
@parameterized.parameters(
    ([{}],),
    (({},),),
)
def test_config_vs_params_dict(self, v):
  """Config wraps a dict found in a sequence; ParamsDict leaves it a dict."""
  d = {'key': v}
  as_config = base_config.Config(d)
  as_params_dict = base_config.params_dict.ParamsDict(d)
  self.assertEqual(type(as_config.key[0]), base_config.Config)
  self.assertEqual(type(as_params_dict.key[0]), dict)
def test_ppformat(self):
  """Pins the `pprint.pformat` output for a list mixing builtin types."""
  nested_mapping = {
      (2,): (3, [4], {6: 7}),
      8: 9,
  }
  sample = ['s', 1, 1.0, True, None, {}, [], (), nested_mapping]
  self.assertEqual(
      pprint.pformat(sample),
      "['s', 1, 1.0, True, None, {}, [], (), {8: 9, (2,): (3, [4], {6: 7})}]")
def test_with_restrictions(self):
  """A DumpConfig2 built with the restriction 'e.a<c' passes validate()."""
  config = DumpConfig2(restrictions=['e.a<c'])
  config.validate()
def test_nested_tuple(self):
  """Overrides the `y` and `z` sequences of DummyConfig5 and checks them."""
  first_entry = {'c': 4, 'd': 'new text 3', 'e': {'a': 2}}
  second_entry = {'c': 0, 'd': 'new text 3', 'e': {'a': 2}}
  config = DummyConfig5()
  config.override({
      'y': [first_entry, second_entry],
      'z': ['a', 'b', 'c'],
  })

  self.assertEqual(config.y[0].c, 4)
  self.assertEqual(config.y[1].c, 0)
  # Each element is expected to have its own config type after the override.
  self.assertIsInstance(config.y[0], DumpConfig2)
  self.assertIsInstance(config.y[1], DumpConfig4)
  self.assertSameElements(config.z, ['a', 'b', 'c'])
def test_override_by_empty_sequence(self):
  """Strictly overriding `y` and `z` with empty sequences empties both."""
  config = DummyConfig5()
  empty_values = {'y': [], 'z': ()}
  config.override(empty_values, is_strict=True)
  self.assertEmpty(config.y)
  self.assertEmpty(config.z)
if __name__ == '__main__':
  # Hand control to the TensorFlow test runner when executed as a script.
  tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Config class that supports oneof functionality."""
from typing import Optional
import dataclasses
from official.modeling.hyperparams import base_config
@dataclasses.dataclass
class OneOfConfig(base_config.Config):
  """Config in which the `type` field names the one field that is selected.

  Attributes:
    type: 'str', name of the field to select.
  """
  type: Optional[str] = None

  def as_dict(self):
    """Exports only `type` and the field it selects.

    Returns:
      `{'type': None}` when no type is set; otherwise a dict holding `type`
      and the selected field, the latter passed through `_export_config`.

    Raises:
      ValueError: if `type` does not name an attribute of this instance.
    """
    if self.type is None:
      return {'type': None}

    if self.__dict__['type'] not in self.__dict__:
      raise ValueError('type: {!r} is not a valid key!'.format(
          self.__dict__['type']))

    selected_name = self.type
    selected_value = self.__dict__[selected_name]
    exported = self._export_config(selected_value)
    return {'type': self.type, selected_name: exported}

  def get(self):
    """Returns the field selected by `type`, or None when `type` is None.

    Raises:
      ValueError: if `type` does not name an attribute of this instance.
    """
    selected_name = self.type
    if selected_name is None:
      return None
    if selected_name in self.__dict__:
      return self.__dict__[selected_name]
    raise ValueError('type: {!r} is not a valid key!'.format(self.type))
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import tensorflow as tf
from official.modeling.hyperparams import base_config
from official.modeling.hyperparams import oneof
@dataclasses.dataclass
class ResNet(base_config.Config):
  """Test config with a single integer field, `model_depth`."""
  model_depth: int = 50
@dataclasses.dataclass
class Backbone(oneof.OneOfConfig):
  """Test oneof config that selects between `resnet` and `not_resnet`."""
  type: str = 'resnet'
  # Use a factory rather than a shared `ResNet()` instance: dataclass instances
  # are unhashable by default, and Python 3.11+ rejects unhashable defaults
  # with "mutable default ... is not allowed: use default_factory".
  resnet: ResNet = dataclasses.field(default_factory=ResNet)
  not_resnet: int = 2
@dataclasses.dataclass
class OutputLayer(oneof.OneOfConfig):
  """Test oneof config that selects between `single` and `multi_head`."""
  type: str = 'single'
  single: int = 1
  multi_head: int = 2
@dataclasses.dataclass
class Network(base_config.Config):
  """Test config composed of a `Backbone` and an `OutputLayer`."""
  # Factories give every Network its own sub-configs; instance defaults of
  # dataclass types are unhashable and are rejected by Python 3.11+.
  backbone: Backbone = dataclasses.field(default_factory=Backbone)
  output_layer: OutputLayer = dataclasses.field(default_factory=OutputLayer)
class OneOfTest(tf.test.TestCase):
  """Tests `as_dict()` and `get()` on oneof-based configs."""

  def test_to_dict(self):
    """A Network built from a dict exports an equal dict."""
    backbone_params = {'type': 'resnet', 'resnet': {'model_depth': 50}}
    output_layer_params = {'type': 'single', 'single': 1000}
    network_params = {
        'backbone': backbone_params,
        'output_layer': output_layer_params,
    }
    network_config = Network(network_params)
    self.assertEqual(network_config.as_dict(), network_params)

  def test_get_oneof(self):
    """A default Backbone selects its ResNet config."""
    selected = Backbone().get()
    self.assertIsInstance(selected, ResNet)
    # `get()` is called a second time, on a fresh default Backbone.
    self.assertEqual(Backbone().get().as_dict(), {'model_depth': 50})
if __name__ == '__main__':
  # Hand control to the TensorFlow test runner when executed as a script.
  tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A parameter dictionary class which supports the nest structure."""
import collections
import copy
import re
import six
import tensorflow as tf
import yaml
# regex pattern that matches on key-value pairs in a comma-separated
# key-value pair string. It splits each k-v pair on the = sign, and
# matches on values that are within single quotes, double quotes, single
# values (e.g. floats, ints, etc.), and a lists within brackets.
_PARAM_RE = re.compile(
r"""
(?P<name>[a-zA-Z][\w\.]*) # variable name: "var" or "x"
\s*=\s*
((?P<val>\'(.*?)\' # single quote
|
\"(.*?)\" # double quote
|
[^,\[]* # single value
|
\[[^\]]*\])) # list of values
($|,\s*)""", re.VERBOSE)
_CONST_VALUE_RE = re.compile(r'(\d.*|-\d.*|None)')
# YAML loader with an additional implicit resolver so that floats written in
# decimal or exponential form are tagged as floats. The regular expression
# covers the following cases:
# 1- Decimal number with an optional exponential term, e.g. `-1.5E-3`, `1.5e3`.
# 2- Integer number with an exponential term, e.g. `1e+3`.
# 3- Decimal number with a leading dot and an optional exponential term.
# 4- Colon-separated (sexagesimal) number with a decimal part.
#
# The pattern is a raw string, so a literal dot is written `\.`; `\\.` would
# instead mean "a backslash followed by any character" and never match a
# number.
#
# Note that LOADER is `yaml.SafeLoader` itself, not a subclass, so the resolver
# is registered on `yaml.SafeLoader`.
LOADER = yaml.SafeLoader
LOADER.add_implicit_resolver(
    'tag:yaml.org,2002:float',
    re.compile(
        r'''
        ^(?:[-+]?(?:[0-9][0-9_]*)\.[0-9_]*(?:[eE][-+]?[0-9]+)?
        |
        [-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
        |
        \.[0-9_]+(?:[eE][-+][0-9]+)?
        |
        [-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\.[0-9_]*)$''', re.X),
    list('-+0123456789.'))
class ParamsDict(object):
  """A hyperparameter container class.

  Parameters are stored as instance attributes. Dict values are wrapped in
  nested `ParamsDict` objects; every other value is deep-copied on assignment.
  The names in `RESERVED_ATTR` hold internal state and cannot be used as
  parameter keys.
  """

  RESERVED_ATTR = ['_locked', '_restrictions']

  def __init__(self, default_params=None, restrictions=None):
    """Instantiate a ParamsDict.

    Instantiate a ParamsDict given a set of default parameters and a list of
    restrictions. The restrictions are only stored here; call `validate()` to
    check them.

    Args:
      default_params: a Python dict or another ParamsDict object including the
        default parameters to initialize. `None` means no parameters.
      restrictions: a list of strings, which define a list of restrictions to
        ensure the consistency of different parameters internally. Each
        restriction string is defined as a binary relation with a set of
        operators, including {'==', '!=', '<', '<=', '>', '>='}.
    """
    self._locked = False
    self._restrictions = []
    if restrictions:
      self._restrictions = restrictions
    if default_params is None:
      default_params = {}
    self.override(default_params, is_strict=False)

  def _set(self, k, v):
    """Stores `v` under `k`, wrapping dicts and deep-copying everything else."""
    if isinstance(v, dict):
      self.__dict__[k] = ParamsDict(v)
    else:
      self.__dict__[k] = copy.deepcopy(v)

  def __setattr__(self, k, v):
    """Sets the value of the existing key.

    Note that this does not allow directly defining a new key. Use the
    `override` method with `is_strict=False` instead.

    Args:
      k: the key string.
      v: the value to be used to set the key `k`.

    Raises:
      KeyError: if k is not defined in the ParamsDict.
      ValueError: if the ParamsDict instance has been locked.
    """
    # Reserved attributes bypass both checks so that `__init__` and `lock()`
    # can always write the internal state.
    if k not in ParamsDict.RESERVED_ATTR:
      if k not in self.__dict__:
        raise KeyError('The key `{}` does not exist. '
                       'To extend the existing keys, use '
                       '`override` with `is_strict` = False.'.format(k))
      if self._locked:
        raise ValueError('The ParamsDict has been locked. '
                         'No change is allowed.')
    self._set(k, v)

  def __getattr__(self, k):
    """Gets the value of the existing key.

    Args:
      k: the key string.

    Returns:
      the value of the key.

    Raises:
      AttributeError: if k is not defined in the ParamsDict.
    """
    if k not in self.__dict__:
      raise AttributeError('The key `{}` does not exist. '.format(k))
    return self.__dict__[k]

  def __contains__(self, key):
    """Implements the membership test operator."""
    return key in self.__dict__

  def get(self, key, value=None):
    """Accesses through built-in dictionary get method."""
    return self.__dict__.get(key, value)

  def __delattr__(self, k):
    """Deletes the key and removes its values.

    Args:
      k: the key string.

    Raises:
      AttributeError: if k is reserved or not defined in the ParamsDict.
      ValueError: if the ParamsDict instance has been locked.
    """
    if k in ParamsDict.RESERVED_ATTR:
      raise AttributeError(
          'The key `{}` is reserved. No change is allowes. '.format(k))
    if k not in self.__dict__:
      raise AttributeError('The key `{}` does not exist. '.format(k))
    if self._locked:
      raise ValueError('The ParamsDict has been locked. No change is allowed.')
    del self.__dict__[k]

  def override(self, override_params, is_strict=True):
    """Override the ParamsDict with a set of given params.

    Args:
      override_params: a dict or a ParamsDict specifying the parameters to be
        overridden.
      is_strict: a boolean specifying whether override is strict or not. If
        True, keys in `override_params` must be present in the ParamsDict. If
        False, keys in `override_params` can be different from what is currently
        defined in the ParamsDict. In this case, the ParamsDict will be extended
        to include the new keys.

    Raises:
      ValueError: if the ParamsDict instance has been locked.
      KeyError: if a key is reserved, or is unknown while `is_strict` is True.
    """
    if self._locked:
      raise ValueError('The ParamsDict has been locked. No change is allowed.')
    if isinstance(override_params, ParamsDict):
      override_params = override_params.as_dict()
    self._override(override_params, is_strict)  # pylint: disable=protected-access

  def _override(self, override_dict, is_strict=True):
    """The implementation of `override`."""
    for k, v in override_dict.items():
      if k in ParamsDict.RESERVED_ATTR:
        raise KeyError('The key `{}` is internally reserved. '
                       'Can not be overridden.'.format(k))
      if k not in self.__dict__:
        if is_strict:
          raise KeyError('The key `{}` does not exist. '
                         'To extend the existing keys, use '
                         '`override` with `is_strict` = False.'.format(k))
        else:
          self._set(k, v)
      else:
        # Dict-like values are merged recursively into the existing entry;
        # anything else replaces it.
        if isinstance(v, dict):
          self.__dict__[k]._override(v, is_strict)  # pylint: disable=protected-access
        elif isinstance(v, ParamsDict):
          self.__dict__[k]._override(v.as_dict(), is_strict)  # pylint: disable=protected-access
        else:
          self.__dict__[k] = copy.deepcopy(v)

  def lock(self):
    """Makes the ParamsDict immutable."""
    self._locked = True

  def as_dict(self):
    """Returns a dict representation of ParamsDict.

    For the nested ParamsDict, a nested dict will be returned. Reserved
    attributes are left out; other values are deep-copied.
    """
    params_dict = {}
    for k, v in self.__dict__.items():
      if k not in ParamsDict.RESERVED_ATTR:
        if isinstance(v, ParamsDict):
          params_dict[k] = v.as_dict()
        else:
          params_dict[k] = copy.deepcopy(v)
    return params_dict

  def validate(self):
    """Validate the parameters consistency based on the restrictions.

    This method validates the internal consistency using the pre-defined list of
    restrictions. A restriction is defined as a string which specifies a binary
    operation. The supported binary operations are {'==', '!=', '<', '<=', '>',
    '>='}. Note that the meaning of these operators are consistent with the
    underlying Python implementation. Users should make sure the defined
    restrictions on their type make sense.

    For example, for a ParamsDict like the following
    ```
    a:
      a1: 1
      a2: 2
    b:
      bb:
        bb1: 10
        bb2: 20
      ccc:
        a1: 1
        a3: 3
    ```
    one can define two restrictions like this
    ['a.a1 == b.ccc.a1', 'a.a2 <= b.bb.bb2']

    What it enforces are:
     - a.a1 = 1 == b.ccc.a1 = 1
     - a.a2 = 2 <= b.bb.bb2 = 20

    Raises:
      KeyError: if any of the following happens
        (1) any of parameters in any of restrictions is not defined in
            ParamsDict,
        (2) any inconsistency violating the restriction is found.
      ValueError: if the restriction defined in the string is not supported.
    """

    def _get_kv(dotted_string, params_dict):
      """Get keys and values indicated by dotted_string."""
      if _CONST_VALUE_RE.match(dotted_string) is not None:
        # The operand is a constant: `None` or something parsed as a float.
        const_str = dotted_string
        if const_str == 'None':
          constant = None
        else:
          constant = float(const_str)
        return None, constant
      else:
        tokenized_params = dotted_string.split('.')
        v = params_dict
        for t in tokenized_params:
          v = v[t]
        return tokenized_params[-1], v

    def _get_kvs(tokens, params_dict):
      """Resolves the two operands of a binary restriction."""
      if len(tokens) != 2:
        raise ValueError('Only support binary relation in restriction.')
      stripped_tokens = [t.strip() for t in tokens]
      left_k, left_v = _get_kv(stripped_tokens[0], params_dict)
      right_k, right_v = _get_kv(stripped_tokens[1], params_dict)
      return left_k, left_v, right_k, right_v

    # Each entry pairs an operator with the predicate that detects a violation.
    # '<=' and '>=' must be tried before '<' and '>': otherwise 'a <= b' would
    # be split on '<', leaving '= b' as a bogus right-hand operand.
    relations = (
        ('==', lambda left, right: left != right),
        ('!=', lambda left, right: left == right),
        ('<=', lambda left, right: left > right),
        ('>=', lambda left, right: left < right),
        ('<', lambda left, right: left >= right),
        ('>', lambda left, right: left <= right),
    )

    params_dict = self.as_dict()
    for restriction in self._restrictions:
      for symbol, is_violated in relations:
        if symbol not in restriction:
          continue
        tokens = restriction.split(symbol)
        _, left_v, _, right_v = _get_kvs(tokens, params_dict)
        if is_violated(left_v, right_v):
          raise KeyError(
              'Found inconsistncy between key `{}` and key `{}`.'.format(
                  tokens[0], tokens[1]))
        break
      else:
        raise ValueError('Unsupported relation in restriction.')
def read_yaml_to_params_dict(file_path: str):
  """Loads the YAML file at `file_path` with `LOADER` into a new ParamsDict."""
  with tf.io.gfile.GFile(file_path, 'r') as yaml_file:
    loaded = yaml.load(yaml_file, Loader=LOADER)
  return ParamsDict(loaded)
def save_params_dict_to_yaml(params, file_path):
  """Writes `params.as_dict()` to `file_path` as YAML.

  Lists are written in flow style (`[a, b]`); everything else uses block
  style. The list representer is registered through `yaml.add_representer`
  on every call.

  Args:
    params: the ParamsDict to save.
    file_path: destination path, opened for writing via `tf.io.gfile`.
  """

  def _flow_style_list(dumper, data):
    # u'tag:yaml.org,2002:seq' is the YAML internal tag for sequence.
    return dumper.represent_sequence(
        u'tag:yaml.org,2002:seq', data, flow_style=True)

  with tf.io.gfile.GFile(file_path, 'w') as out_file:
    yaml.add_representer(list, _flow_style_list)
    yaml.dump(params.as_dict(), out_file, default_flow_style=False)
def nested_csv_str_to_json_str(csv_str):
  """Converts a nested (using '.') comma-separated k=v string to a JSON string.

  Each `key=value` entry becomes `key : value`. A dotted key such as `c.a` is
  grouped under its first component, and the remainder of each group is
  converted recursively, so nesting of any depth is supported.

  Spacing between commas and = is supported (e.g. there is no difference between
  "a=1,b=2", "a = 1, b = 2", or "a=1, b=2") but there should be no spaces before
  keys or after values (e.g. " a=1,b=2" and "a=1,b=2 " are not supported).

  Note that this will only support values supported by CSV, meaning
  values such as nested lists (e.g. "a=[[1,2,3],[4,5,6]]") are not
  supported. Strings are supported as well, e.g. "a='hello'".

  An example conversion would be:

  "a=1, b=2, c.a=2, c.b=3, d.a.a=5"

  to

  "{ a: 1, b : 2, c: {a : 2, b : 3}, d: {a: {a : 5}}}"

  Args:
    csv_str: the comma separated string.

  Returns:
    the converted JSON string, or '' when `csv_str` is empty.

  Raises:
    ValueError: If csv_str is not in a comma separated string or
      if the string is formatted incorrectly.
  """
  if not csv_str:
    return ''

  flat_entries = []
  # Maps the first component of a dotted key to its remaining "rest=value"
  # entries, which are converted recursively below.
  groups = collections.defaultdict(list)
  cursor = 0
  total = len(csv_str)
  while cursor < total:
    match = _PARAM_RE.match(csv_str, cursor)
    if not match:
      raise ValueError('Malformed hyperparameter value while parsing '
                       'CSV string: %s' % csv_str[cursor:])
    cursor = match.end()

    fields = match.groupdict()
    name, raw_value = fields['name'], fields['val']

    # If a GCS path (e.g. gs://...) is provided, wrap this in quotes
    # as yaml.load would otherwise throw an exception.
    # NOTE(review): `[gs://]` is a character class, so any unquoted value
    # starting with 'g', 's', ':' or '/' is quoted as well.
    if re.match(r'(?=[^\"\'])(?=[gs://])', raw_value):
      raw_value = '\'{}\''.format(raw_value)

    head, dot, remainder = name.partition('.')
    if dot:
      groups[head].append(remainder + '=' + raw_value)
    else:
      flat_entries.append('%s : %s' % (name, raw_value))

  for head, members in groups.items():
    nested_json = nested_csv_str_to_json_str(','.join(members))
    flat_entries.append('%s : %s' % (head, nested_json))
  return '{' + ', '.join(flat_entries) + '}'
def override_params_dict(params, dict_or_string_or_yaml_file, is_strict):
  """Override a given ParamsDict using a dict, JSON/YAML/CSV string or YAML file.

  The input is handled as follows:
    1. An empty input leaves `params` unchanged.
    2. A dict is used directly as the override.
    3. A string is first converted from CSV to JSON when it parses as CSV. The
       result (or the original string, if the conversion failed) is loaded as
       YAML/JSON; if that yields a dict it is used as the override, otherwise
       the string is treated as the path of a YAML file to load.
    4. Any other input type is an error.

  Args:
    params: a ParamsDict object to be overridden.
    dict_or_string_or_yaml_file: a Python dict, JSON/YAML/CSV string or path to
      a YAML file specifying the parameters to be overridden.
    is_strict: a boolean specifying whether override is strict or not.

  Returns:
    params: the overridden ParamsDict object.

  Raises:
    ValueError: if failed to override the parameters.
  """
  source = dict_or_string_or_yaml_file
  if not source:
    return params

  if isinstance(source, dict):
    params.override(source, is_strict)
    return params

  if not isinstance(source, six.string_types):
    raise ValueError('Unknown input type to parse.')

  try:
    source = nested_csv_str_to_json_str(source)
  except ValueError:
    # Not a CSV string; keep the original text for the YAML/JSON attempt.
    pass

  parsed = yaml.load(source, Loader=LOADER)
  if isinstance(parsed, dict):
    params.override(parsed, is_strict)
  else:
    # NOTE(review): the file branch uses yaml.FullLoader while strings use
    # LOADER; confirm that FullLoader is intended for file contents.
    with tf.io.gfile.GFile(source) as yaml_file:
      params.override(yaml.load(yaml_file, Loader=yaml.FullLoader), is_strict)
  return params
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for params_dict.py."""
import os
import tensorflow as tf
import yaml
from official.modeling.hyperparams import params_dict
class ParamsDictTest(tf.test.TestCase):
  """Tests for attribute access, overriding, locking and validation."""

  def test_init_from_an_empty_dict(self):
    """An empty ParamsDict rejects both reading and setting unknown keys."""
    empty = params_dict.ParamsDict()
    with self.assertRaises(AttributeError):
      _ = empty.a
    with self.assertRaises(KeyError):
      empty.a = 'aa'

  def test_init_from_a_dict(self):
    """Keys of the initial dict become attributes."""
    params = params_dict.ParamsDict({'a': 'aa', 'b': 2})
    self.assertEqual(params.a, 'aa')
    self.assertEqual(params.b, 2)

  def test_init_from_a_param_dict(self):
    """A ParamsDict can be initialized from another ParamsDict."""
    source = params_dict.ParamsDict({'a': 'aa', 'b': 2})
    params = params_dict.ParamsDict(source)
    self.assertEqual(params.a, 'aa')
    self.assertEqual(params.b, 2)

  def test_lock(self):
    """After lock(), setting, overriding and deleting raise ValueError."""
    params = params_dict.ParamsDict({'a': 1, 'b': 2, 'c': 3})
    params.lock()
    with self.assertRaises(ValueError):
      params.a = 10
    with self.assertRaises(ValueError):
      params.override({'b': 20})
    with self.assertRaises(ValueError):
      del params.c

  def test_setattr(self):
    """An existing key can be reassigned through attribute access."""
    params = params_dict.ParamsDict()
    params.override({'a': 'aa', 'b': 2, 'c': None}, is_strict=False)
    params.c = 'ccc'
    self.assertEqual(params.a, 'aa')
    self.assertEqual(params.b, 2)
    self.assertEqual(params.c, 'ccc')

  def test_getattr(self):
    """Values added by a non-strict override are readable as attributes."""
    params = params_dict.ParamsDict()
    params.override({'a': 'aa', 'b': 2, 'c': None}, is_strict=False)
    self.assertEqual(params.a, 'aa')
    self.assertEqual(params.b, 2)
    self.assertEqual(params.c, None)

  def test_delattr(self):
    """Deleted keys, scalar or nested, can no longer be read."""
    params = params_dict.ParamsDict()
    initial = {'a': 'aa', 'b': 2, 'c': None, 'd': {'d1': 1, 'd2': 10}}
    params.override(initial, is_strict=False)

    del params.c
    self.assertEqual(params.a, 'aa')
    self.assertEqual(params.b, 2)
    with self.assertRaises(AttributeError):
      _ = params.c

    del params.d
    with self.assertRaises(AttributeError):
      _ = params.d.d1

  def test_contains(self):
    """The `in` operator reports which keys are defined."""
    params = params_dict.ParamsDict()
    params.override({'a': 'aa'}, is_strict=False)
    self.assertIn('a', params)
    self.assertNotIn('b', params)

  def test_get(self):
    """get() returns the value, the given default, or None."""
    params = params_dict.ParamsDict()
    params.override({'a': 'aa'}, is_strict=False)
    self.assertEqual(params.get('a'), 'aa')
    self.assertEqual(params.get('b', 2), 2)
    self.assertEqual(params.get('b'), None)

  def test_override_is_strict_true(self):
    """Strict override updates existing keys and rejects unknown ones."""
    initial = {'a': 'aa', 'b': 2, 'c': {'c1': 'cc', 'c2': 20}}
    params = params_dict.ParamsDict(initial)
    params.override({'a': 2, 'c': {'c1': 'ccc'}}, is_strict=True)
    self.assertEqual(params.a, 2)
    self.assertEqual(params.c.c1, 'ccc')
    with self.assertRaises(KeyError):
      params.override({'d': 'ddd'}, is_strict=True)
    with self.assertRaises(KeyError):
      params.override({'c': {'c3': 30}}, is_strict=True)

  def test_override_is_strict_false(self):
    """Non-strict override adds new top-level and nested keys."""
    initial = {'a': 'aa', 'b': 2, 'c': {'c1': 10, 'c2': 20}}
    params = params_dict.ParamsDict(initial)

    params.override({'a': 2, 'c': {'c3': 3000}}, is_strict=False)
    self.assertEqual(params.a, 2)
    self.assertEqual(params.c.c3, 3000)

    params.override({'d': 'ddd'}, is_strict=False)
    self.assertEqual(params.d, 'ddd')

    params.override({'c': {'c4': 4444}}, is_strict=False)
    self.assertEqual(params.c.c4, 4444)

  def test_as_dict(self):
    """as_dict() returns plain nested dicts holding the same values."""
    initial = {'a': 'aa', 'b': 2, 'c': {'c1': 10, 'c2': 20}}
    exported = params_dict.ParamsDict(initial).as_dict()
    self.assertEqual(exported['a'], 'aa')
    self.assertEqual(exported['b'], 2)
    self.assertEqual(exported['c']['c1'], 10)
    self.assertEqual(exported['c']['c2'], 20)

  def test_validate(self):
    """validate() raises KeyError for unknown keys and violated restrictions."""
    # Raise error due to the unknown parameter.
    with self.assertRaises(KeyError):
      params = params_dict.ParamsDict({'a': 1, 'b': {'a': 11}}, ['a == c'])
      params.validate()

    # OK to check equality of two nested dicts.
    nested_equal = {'a': 1, 'b': {'a': 10}, 'c': {'a': 10}}
    params = params_dict.ParamsDict(nested_equal, ['b == c'])

    # Raise error due to inconsistency
    with self.assertRaises(KeyError):
      params = params_dict.ParamsDict({'a': 1, 'c': {'a': 10}}, ['a == c.a'])
      params.validate()

    # Valid rule.
    params = params_dict.ParamsDict({'a': 1, 'c': {'a': 1}}, ['a == c.a'])

    # Overridding violates the existing rule, raise error upon validate.
    params.override({'a': 11})
    with self.assertRaises(KeyError):
      params.validate()

    # Valid restrictions with constant.
    constant_rules = ['a == None', 'c.a == 1']
    params = params_dict.ParamsDict({'a': None, 'c': {'a': 1}}, constant_rules)
    params.validate()
    with self.assertRaises(KeyError):
      params = params_dict.ParamsDict({
          'a': 4,
          'c': {
              'a': 1
          }
      }, ['a == None', 'c.a == 1'])
      params.validate()
class ParamsDictIOTest(tf.test.TestCase):
  """Tests for saving, reading and overriding a ParamsDict from text inputs."""

  def write_temp_file(self, filename, text):
    """Writes `text` to `filename` in the test temp dir and returns its path."""
    temp_file = os.path.join(self.get_temp_dir(), filename)
    with tf.io.gfile.GFile(temp_file, 'w') as writer:
      writer.write(text)
    return temp_file

  def test_save_params_dict_to_yaml(self):
    """A saved ParamsDict can be read back with the same values."""
    params = params_dict.ParamsDict({
        'a': 'aa',
        'b': 2,
        'c': {
            'c1': 10,
            'c2': 20
        }
    })
    output_yaml_file = os.path.join(self.get_temp_dir(), 'params.yaml')
    params_dict.save_params_dict_to_yaml(params, output_yaml_file)

    with tf.io.gfile.GFile(output_yaml_file, 'r') as f:
      # `yaml.load` without an explicit Loader raises TypeError on PyYAML >= 6.
      params_d = yaml.safe_load(f)
      self.assertEqual(params.a, params_d['a'])
      self.assertEqual(params.b, params_d['b'])
      self.assertEqual(params.c.c1, params_d['c']['c1'])
      self.assertEqual(params.c.c2, params_d['c']['c2'])

  def test_read_yaml_to_params_dict(self):
    """A nested YAML file is read into nested attributes."""
    # `c1` and `c2` must be indented under `c` to form a nested mapping.
    input_yaml_file = self.write_temp_file(
        'params.yaml', r"""
a: 'aa'
b: 2
c:
  c1: 10
  c2: 20
""")
    params = params_dict.read_yaml_to_params_dict(input_yaml_file)

    self.assertEqual(params.a, 'aa')
    self.assertEqual(params.b, 2)
    self.assertEqual(params.c.c1, 10)
    self.assertEqual(params.c.c2, 20)

  def test_override_params_dict_using_dict(self):
    """A dict overrides only the keys it names."""
    params = params_dict.ParamsDict({
        'a': 1,
        'b': 2.5,
        'c': [3, 4],
        'd': 'hello',
        'e': False
    })
    override_dict = {'b': 5.2, 'c': [30, 40]}
    params = params_dict.override_params_dict(
        params, override_dict, is_strict=True)
    self.assertEqual(1, params.a)
    self.assertEqual(5.2, params.b)
    self.assertEqual([30, 40], params.c)
    self.assertEqual('hello', params.d)
    self.assertEqual(False, params.e)

  def test_override_params_dict_using_yaml_string(self):
    """A YAML string overrides only the keys it names."""
    params = params_dict.ParamsDict({
        'a': 1,
        'b': 2.5,
        'c': [3, 4],
        'd': 'hello',
        'e': False
    })
    override_yaml_string = "'b': 5.2\n'c': [30, 40]"
    params = params_dict.override_params_dict(
        params, override_yaml_string, is_strict=True)
    self.assertEqual(1, params.a)
    self.assertEqual(5.2, params.b)
    self.assertEqual([30, 40], params.c)
    self.assertEqual('hello', params.d)
    self.assertEqual(False, params.e)

  def test_override_params_dict_using_json_string(self):
    """A JSON-like string overrides nested keys."""
    params = params_dict.ParamsDict({
        'a': 1,
        'b': {
            'b1': 2,
            'b2': [2, 3],
        },
        'd': {
            'd1': {
                'd2': 'hello'
            }
        },
        'e': False
    })
    override_json_string = "{ b: { b2: [3, 4] }, d: { d1: { d2: 'hi' } } }"
    params = params_dict.override_params_dict(
        params, override_json_string, is_strict=True)
    self.assertEqual(1, params.a)
    self.assertEqual(2, params.b.b1)
    self.assertEqual([3, 4], params.b.b2)
    self.assertEqual('hi', params.d.d1.d2)
    self.assertEqual(False, params.e)

  def test_override_params_dict_using_csv_string(self):
    """A CSV string overrides nested keys, GCS paths and float formats."""
    params = params_dict.ParamsDict({
        'a': 1,
        'b': {
            'b1': 2,
            'b2': [2, 3],
        },
        'd': {
            'd1': {
                'd2': 'hello'
            }
        },
        'e': False
    })
    override_csv_string = "b.b2=[3,4], d.d1.d2='hi, world', e=gs://test"
    params = params_dict.override_params_dict(
        params, override_csv_string, is_strict=True)
    self.assertEqual(1, params.a)
    self.assertEqual(2, params.b.b1)
    self.assertEqual([3, 4], params.b.b2)
    self.assertEqual('hi, world', params.d.d1.d2)
    self.assertEqual('gs://test', params.e)
    # Test different float formats
    override_csv_string = 'b.b2=-1.e-3, d.d1.d2=+0.001, e=1e+3, a=-1.5E-3'
    params = params_dict.override_params_dict(
        params, override_csv_string, is_strict=True)
    self.assertEqual(-1e-3, params.b.b2)
    self.assertEqual(0.001, params.d.d1.d2)
    self.assertEqual(1e3, params.e)
    self.assertEqual(-1.5e-3, params.a)

  def test_override_params_dict_using_yaml_file(self):
    """A path to a YAML file overrides only the keys the file names."""
    params = params_dict.ParamsDict({
        'a': 1,
        'b': 2.5,
        'c': [3, 4],
        'd': 'hello',
        'e': False
    })
    override_yaml_file = self.write_temp_file(
        'params.yaml', r"""
b: 5.2
c: [30, 40]
""")
    params = params_dict.override_params_dict(
        params, override_yaml_file, is_strict=True)
    self.assertEqual(1, params.a)
    self.assertEqual(5.2, params.b)
    self.assertEqual([30, 40], params.c)
    self.assertEqual('hello', params.d)
    self.assertEqual(False, params.e)
class IOTest(tf.test.TestCase):
def test_basic_csv_str_to_json_str(self):
csv_str = 'a=1,b=2,c=3'
json_str = '{a : 1, b : 2, c : 3}'
converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
self.assertEqual(converted_csv_str, json_str)
def test_basic_csv_str_load(self):
csv_str = 'a=1,b=2,c=3'
expected_output = {'a': 1, 'b': 2, 'c': 3}
converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
converted_dict = yaml.load(converted_csv_str)
self.assertDictEqual(converted_dict, expected_output)
def test_basic_nested_csv_str_to_json_str(self):
csv_str = 'a=1,b.b1=2'
json_str = '{a : 1, b : {b1 : 2}}'
converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
self.assertEqual(converted_csv_str, json_str)
def test_basic_nested_csv_str_load(self):
csv_str = 'a=1,b.b1=2,c.c1=3'
expected_output = {'a': 1, 'b': {'b1': 2}, 'c': {'c1': 3}}
converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
converted_dict = yaml.load(converted_csv_str)
self.assertDictEqual(converted_dict, expected_output)
def test_complex_nested_csv_str_to_json_str(self):
csv_str = 'a.aa.aaa.aaaaa.a=1'
json_str = '{a : {aa : {aaa : {aaaaa : {a : 1}}}}}'
converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
self.assertEqual(converted_csv_str, json_str)
def test_complex_nested_csv_str_load(self):
csv_str = 'a.aa.aaa.aaaaa.a=1,a.a=2'
expected_output = {'a': {'aa': {'aaa': {'aaaaa': {'a': 1}}}, 'a': 2}}
converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
converted_dict = yaml.load(converted_csv_str)
self.assertDictEqual(converted_dict, expected_output)
def test_csv_str_load_supported_datatypes(self):
csv_str = 'a=1,b=2.,c=[1,2,3],d=\'hello, there\',e=\"Hi.\"'
converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
converted_dict = yaml.load(converted_csv_str)
self.assertEqual(converted_dict['a'], 1)
self.assertEqual(converted_dict['b'], 2.)
self.assertEqual(converted_dict['c'], [1, 2, 3])
self.assertEqual(converted_dict['d'], 'hello, there')
self.assertEqual(converted_dict['e'], 'Hi.')
def test_csv_str_load_unsupported_datatypes(self):
csv_str = 'a=[[1,2,3],[4,5,6]]'
self.assertRaises(ValueError, params_dict.nested_csv_str_to_json_str,
csv_str)
def test_csv_str_to_json_str_spacing(self):
csv_str1 = 'a=1,b=2,c=3'
csv_str2 = 'a = 1, b = 2, c = 3'
json_str = '{a : 1, b : 2, c : 3}'
converted_csv_str1 = params_dict.nested_csv_str_to_json_str(csv_str1)
converted_csv_str2 = params_dict.nested_csv_str_to_json_str(csv_str2)
self.assertEqual(converted_csv_str1, converted_csv_str2)
self.assertEqual(converted_csv_str1, json_str)
self.assertEqual(converted_csv_str2, json_str)
def test_gcs_added_quotes(self):
  """Checks that unquoted `gs://` values come back wrapped in single quotes."""
  expected_output = '{a : \'gs://abc\', b : \'gs://def\'}'
  actual_output = params_dict.nested_csv_str_to_json_str(
      'a=gs://abc, b=gs://def')
  self.assertEqual(actual_output, expected_output)
if __name__ == '__main__':
  # Run the test cases in this module through the TensorFlow test runner.
  tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Abstraction of multi-task model."""
from typing import Text, Dict
import tensorflow as tf
class MultiTaskBaseModel(tf.Module):
  """Base class that holds multi-task model computation.

  Subclasses override `_instantiate_sub_tasks`; the constructor calls it once
  and keeps the returned mapping, which is exposed through `sub_tasks`.
  """

  def __init__(self, **kwargs):
    """Initializes the module and builds the sub-task mapping.

    Args:
      **kwargs: Forwarded unchanged to `tf.Module.__init__`.
    """
    super().__init__(**kwargs)
    self._sub_tasks = self._instantiate_sub_tasks()

  def _instantiate_sub_tasks(self) -> Dict[Text, tf.keras.Model]:
    """Abstract function that sets up the computation for each sub-task.

    Returns:
      A map from task name (as string) to a tf.keras.Model object that
      represents the sub-task in the multi-task pool.

    Raises:
      NotImplementedError: Always, unless overridden by a subclass.
    """
    # The message names this method so the error points at the hook that a
    # subclass actually has to override.
    raise NotImplementedError(
        "_instantiate_sub_tasks() is not implemented.")

  @property
  def sub_tasks(self):
    """Fetch a map of task name (string) to task model (tf.keras.Model)."""
    return self._sub_tasks

  def initialize(self):
    """Optional function that loads a pre-train checkpoint.

    The base implementation does nothing and returns None.
    """
    return
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multitask base trainer implementation.
The trainer derives from the Orbit `StandardTrainer` class.
"""
from typing import Union
import gin
import orbit
import tensorflow as tf
from official.modeling import optimization
from official.modeling.multitask import base_model
from official.modeling.multitask import multitask
@gin.configurable
class MultiTaskBaseTrainer(orbit.StandardTrainer):
  """Multitask base trainer.

  Holds a `MultiTask`, the model shared by its tasks and one optimizer, and
  drives `MultiTask.joint_train_step` on every train step while tracking
  per-task loss and metric objects.
  """

  def __init__(self,
               multi_task: multitask.MultiTask,
               multi_task_model: Union[tf.keras.Model,
                                       base_model.MultiTaskBaseModel],
               optimizer: tf.optimizers.Optimizer,
               trainer_options=None,
               train_datasets=None):
    """Sets up trainer state, the checkpoint and the training datasets.

    Args:
      multi_task: The `MultiTask` whose `tasks` are trained jointly.
      multi_task_model: The model passed to the joint train step and saved in
        the checkpoint.
      optimizer: The optimizer passed to the joint train step.
      trainer_options: Options forwarded to `orbit.StandardTrainer`; when
        falsy, a default `orbit.StandardTrainerOptions()` is used.
      train_datasets: Optional mapping of task name to dataset. When None, a
        distributed dataset is built for every task from its `build_inputs`
        and `task_config.train_data`.
    """
    self._strategy = tf.distribute.get_strategy()
    self._multi_task = multi_task
    self._multi_task_model = multi_task_model
    self._optimizer = optimizer
    # Loss and metric holders are created lazily by the properties below.
    self._training_losses = None
    self._training_metrics = None
    self._global_step = orbit.utils.create_global_step()

    # Creates a shadow copy of the weights to store weights moving average.
    needs_shadow_copy = (
        isinstance(self._optimizer, optimization.ExponentialMovingAverage) and
        not self._optimizer.has_shadow_copy)
    if needs_shadow_copy:
      self._optimizer.shadow_copy(multi_task_model)

    # Models may expose extra objects to be saved alongside the defaults.
    extra_checkpoint_items = getattr(
        self.multi_task_model, "checkpoint_items", {})
    self._checkpoint = tf.train.Checkpoint(
        model=self.multi_task_model,
        optimizer=self.optimizer,
        global_step=self.global_step,
        **extra_checkpoint_items)

    if train_datasets is None:
      train_datasets = {
          task_name: orbit.utils.make_distributed_dataset(
              self.strategy, task.build_inputs, task.task_config.train_data)
          for task_name, task in self.multi_task.tasks.items()
      }

    super().__init__(
        train_dataset=train_datasets,
        options=trainer_options or orbit.StandardTrainerOptions())

  def train_loop_begin(self):
    """Clean up states that hold losses and metrics."""
    for loss_metric in self.training_losses.values():
      loss_metric.reset_states()
    for task_metrics in self.training_metrics.values():
      for metric in task_metrics:
        metric.reset_states()

  def train_loop_end(self):
    """Record loss and metric values per task.

    Returns:
      A dict with one entry per key of `training_losses`, each mapping
      loss/metric names to their current results, plus a top-level
      "learning_rate" entry.
    """
    result = {
        task_name: {loss.name: loss.result()}
        for task_name, loss in self.training_losses.items()
    }
    for task_name, task_metrics in self.training_metrics.items():
      result[task_name].update(
          {metric.name: metric.result() for metric in task_metrics})

    # Note that the learning rate schedule is managed by the keras optimizer
    # internally, which respects the number of backward passes as
    # `iterations`. The learning rate schedule does not follow the trainer's
    # logical global step of multiple tasks.
    learning_rate = self.optimizer.learning_rate
    if callable(learning_rate):
      learning_rate = learning_rate(self.optimizer.iterations)
    result["learning_rate"] = learning_rate
    return result

  @property
  def checkpoint(self):
    """Accesses the training checkpoint."""
    return self._checkpoint

  @property
  def training_losses(self):
    """Access training loss metric objects for all tasks."""
    if self._training_losses is not None:
      return self._training_losses

    # Builds the per-task losses on first access.
    # "total_loss" is the total summed training loss of the tasks in the
    # joint training.
    self._training_losses = {
        "total_loss": tf.keras.metrics.Mean("training_loss", dtype=tf.float32)
    }
    for task_name in self.multi_task.tasks:
      self._training_losses[task_name] = tf.keras.metrics.Mean(
          "training_loss", dtype=tf.float32)
    return self._training_losses

  @property
  def training_metrics(self):
    """Access training metric objects for all tasks."""
    if self._training_metrics is not None:
      return self._training_metrics

    # Builds the per-task metrics on first access.
    self._training_metrics = {}
    for task_name, task in self.multi_task.tasks.items():
      self._training_metrics[task_name] = task.build_metrics(training=True)
    return self._training_metrics

  @property
  def strategy(self):
    """The distribution strategy captured at construction time."""
    return self._strategy

  @property
  def multi_task(self):
    """The `MultiTask` passed to the constructor."""
    return self._multi_task

  @property
  def multi_task_model(self):
    """The model passed to the constructor."""
    return self._multi_task_model

  @property
  def optimizer(self):
    """The optimizer passed to the constructor."""
    return self._optimizer

  @property
  def global_step(self):
    """The global step variable, incremented once per `train_step` call."""
    return self._global_step

  def train_step(self, iterator_map):
    """The default train step calling the multi-task train step.

    Args:
      iterator_map: a dictionary of task names and per-task dataset iterators.
    """

    def _joint_step(inputs):
      """Runs one joint train step and records the returned losses."""
      step_losses = self.multi_task.joint_train_step(
          inputs,
          multi_task_model=self.multi_task_model,
          optimizer=self.optimizer,
          task_metrics=self.training_metrics)
      for loss_name, loss_value in step_losses.items():
        self.training_losses[loss_name].update_state(loss_value)

    # Pull the next element from every iterator in the map.
    next_inputs = tf.nest.map_structure(next, iterator_map)
    self.strategy.run(_joint_step, args=(next_inputs,))
    self.global_step.assign_add(1)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask.base_trainer."""
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.modeling.multitask import base_trainer
from official.modeling.multitask import configs
from official.modeling.multitask import multitask
from official.modeling.multitask import test_utils
def all_strategy_combinations():
  """Returns eager-mode combinations over the strategies under test."""
  strategies_under_test = [
      strategy_combinations.default_strategy,
      strategy_combinations.cloud_tpu_strategy,
      strategy_combinations.one_device_strategy_gpu,
  ]
  return combinations.combine(distribution=strategies_under_test, mode="eager")
class BaseTrainerTest(tf.test.TestCase, parameterized.TestCase):
  """Tests `MultiTaskBaseTrainer` with the mock foo/bar tasks."""

  def _assert_per_task_results(self, results):
    """Asserts that each task's results hold its loss and accuracy keys."""
    self.assertContainsSubset(["training_loss", "bar_acc"],
                              results["bar"].keys())
    self.assertContainsSubset(["training_loss", "foo_acc"],
                              results["foo"].keys())

  @combinations.generate(all_strategy_combinations())
  def test_multitask_joint_trainer(self, distribution):
    """Trains five steps under each strategy and checks the result keys."""
    with distribution.scope():
      test_multitask = multitask.MultiTask(
          tasks=[
              test_utils.MockFooTask(
                  params=test_utils.FooConfig(), name="foo"),
              test_utils.MockBarTask(
                  params=test_utils.BarConfig(), name="bar")
          ],
          task_weights={"foo": 1.0, "bar": 1.0})
      test_optimizer = tf.keras.optimizers.SGD(0.1)
      model = test_utils.MockMultiTaskModel()
      test_trainer = base_trainer.MultiTaskBaseTrainer(
          multi_task=test_multitask,
          multi_task_model=model,
          optimizer=test_optimizer)
      num_steps = tf.convert_to_tensor(5, dtype=tf.int32)
      results = test_trainer.train(num_steps)
    self._assert_per_task_results(results)

  def test_trainer_with_configs(self):
    """Builds the multi-task from a config, trains and checks the outputs."""
    foo_routine = configs.TaskRoutine(
        task_name="foo",
        task_config=test_utils.FooConfig(),
        task_weight=0.5)
    bar_routine = configs.TaskRoutine(
        task_name="bar",
        task_config=test_utils.BarConfig(),
        task_weight=0.5)
    config = configs.MultiTaskConfig(task_routines=(foo_routine, bar_routine))
    test_multitask = multitask.MultiTask.from_config(config)
    test_optimizer = tf.keras.optimizers.SGD(0.1)
    model = test_utils.MockMultiTaskModel()
    test_trainer = base_trainer.MultiTaskBaseTrainer(
        multi_task=test_multitask,
        multi_task_model=model,
        optimizer=test_optimizer)
    num_steps = tf.convert_to_tensor(5, dtype=tf.int32)
    results = test_trainer.train(num_steps)
    self._assert_per_task_results(results)
    self.assertEqual(test_multitask.task_weight("foo"), 0.5)
    # One global step is expected per train step requested above.
    self.assertEqual(test_trainer.global_step.numpy(), 5)
    self.assertIn("learning_rate", results)
if __name__ == "__main__":
  # Run the test cases in this module through the TensorFlow test runner.
  tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment