Commit 0ae03c71 authored by sunxx1's avatar sunxx1
Browse files

Merge branch 'main' into 'main'

update TensorFlow and TensorFlow2x test code

See merge request dcutoolkit/deeplearing/dlexamples_new!58
parents f270c43a a7666964
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""All necessary imports for registration."""
# pylint: disable=unused-import
from official.nlp import tasks
from official.nlp.configs import experiment_configs
from official.utils.testing import mock_task
from official.vision import beta
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Global streamz counters."""
from tensorflow.python.eager import monitoring
progressive_policy_creation_counter = monitoring.Counter(
"/tensorflow/training/fast_training/progressive_policy_creation",
"Counter for the number of ProgressivePolicy creations.")
stack_vars_to_vars_call_counter = monitoring.Counter(
"/tensorflow/training/fast_training/tf_vars_to_vars",
"Counter for the number of low-level stacking API calls.")
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides TFM orbit actions and associated helper functions/classes."""
import os
from typing import List
from absl import logging
import gin
import orbit
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from official.core import base_trainer
from official.core import config_definitions
from official.modeling import optimization
class PruningActions:
"""Train action to updates pruning related information.
This action updates pruning steps at the end of trainig loop, and log
pruning metrics to tensorboard.
This action must be used when training a pruned model to avoid pruning error.
"""
def __init__(
self,
export_dir: str,
model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer,
):
"""Initializes the instance.
Args:
export_dir: `str` for the export directory of the pruning summaries.
model: `tf.keras.Model` model instance used for training. This will be
used to assign a pruning step to each prunable weight.
optimizer: `tf.keras.optimizers.Optimizer` optimizer instance used for
training. This will be used to find the current training steps.
"""
self._optimizer = optimizer
self.update_pruning_step = tfmot.sparsity.keras.UpdatePruningStep()
self.update_pruning_step.set_model(model)
self.update_pruning_step.on_train_begin()
self.pruning_summaries = tfmot.sparsity.keras.PruningSummaries(
log_dir=export_dir)
model.optimizer = optimizer
self.pruning_summaries.set_model(model)
def __call__(self, output: orbit.runner.Output):
"""Update pruning step and log pruning summaries.
Args:
output: The train output to test.
"""
self.update_pruning_step.on_epoch_end(batch=None)
self.pruning_summaries.on_epoch_begin(epoch=None)
class EMACheckpointing:
"""Eval action to save checkpoint with average weights when EMA is used.
This action swaps the weights of the model with the average weights, then it
saves the checkpoint under export_dir/ema_checkpoints. Checkpointing is
expensive for large models, so doing this action in eval is more efficient
than training.
"""
def __init__(self, export_dir: str, optimizer: tf.keras.optimizers.Optimizer,
checkpoint: tf.train.Checkpoint, max_to_keep: int = 1):
"""Initializes the instance.
Args:
export_dir: `str` for the export directory of the EMA average weights.
optimizer: `tf.keras.optimizers.Optimizer` optimizer instance used for
training. This will be used to swap the model weights with the average
weigths.
checkpoint: `tf.train.Checkpoint` instance.
max_to_keep: `int` for max checkpoints to keep in ema_checkpoints subdir.
"""
if not isinstance(optimizer, optimization.ExponentialMovingAverage):
raise ValueError('Optimizer has to be instance of'
'optimization.ExponentialMovingAverage for'
'EMACheckpointing action')
export_dir = os.path.join(export_dir, 'ema_checkpoints')
tf.io.gfile.makedirs(
os.path.dirname(export_dir))
self._optimizer = optimizer
self._checkpoint = checkpoint
self._checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
directory=export_dir,
max_to_keep=max_to_keep,
checkpoint_name='average_weights')
def __call__(self, output: orbit.runner.Output):
"""Swaps model weights, and saves the checkpoint.
Args:
output: The train or eval output to test.
"""
self._optimizer.swap_weights()
self._checkpoint_manager.save(checkpoint_number=self._optimizer.iterations)
self._optimizer.swap_weights()
class RecoveryAction:
"""Train action to recover from loss blowup.
Checks the loss value by the given threshold. If applicable, recover the
model by reading the checkpoint on disk.
"""
def __init__(self, checkpoint_manager: tf.train.CheckpointManager):
self.checkpoint_manager = checkpoint_manager
def __call__(self, _):
"""Recovers the training by triggering checkpoint restoration."""
# Loads the previous good checkpoint.
checkpoint_path = self.checkpoint_manager.restore_or_initialize()
logging.warning('Recovering the model from checkpoint: %s.',
checkpoint_path)
class RecoveryCondition:
"""Recovery Condition."""
def __init__(self,
global_step: tf.Variable,
loss_upper_bound: float,
recovery_begin_steps: int = 0,
recovery_max_trials: int = 3):
self.recover_counter = 0
self.recovery_begin_steps = recovery_begin_steps
self.recovery_max_trials = recovery_max_trials
self.loss_upper_bound = loss_upper_bound
self.global_step = global_step
def __call__(self, outputs: orbit.runner.Output):
loss_value = outputs['training_loss']
if tf.math.is_nan(loss_value):
self.recover_counter += 1
if self.recover_counter > self.recovery_max_trials:
raise RuntimeError(
'The loss value is NaN after training loop and it happens %d times.'
% self.recover_counter)
return True
if (self.global_step >= self.recovery_begin_steps and
loss_value > self.loss_upper_bound):
self.recover_counter += 1
if self.recover_counter > self.recovery_max_trials:
raise RuntimeError(
f'The loss value is {loss_value}, which is larger than the bound {self.loss_upper_bound}, happens {self.recover_counter} times.'
)
return True
return False
@gin.configurable
def get_eval_actions(
params: config_definitions.ExperimentConfig,
trainer: base_trainer.Trainer,
model_dir: str) -> List[orbit.Action]:
"""Gets eval actions for TFM trainer."""
eval_actions = []
# Adds ema checkpointing action to save the average weights under
# ema_checkpoints subdir.
if isinstance(trainer.optimizer, optimization.ExponentialMovingAverage):
eval_actions.append(
EMACheckpointing(
export_dir=model_dir,
optimizer=trainer.optimizer,
checkpoint=trainer.checkpoint,
max_to_keep=params.trainer.max_to_keep))
return eval_actions
@gin.configurable
def get_train_actions(
params: config_definitions.ExperimentConfig, trainer: base_trainer.Trainer,
model_dir: str,
checkpoint_manager: tf.train.CheckpointManager) -> List[orbit.Action]:
"""Gets train actions for TFM trainer."""
train_actions = []
# Adds pruning callback actions.
if hasattr(params.task, 'pruning'):
train_actions.append(
PruningActions(
export_dir=model_dir,
model=trainer.model,
optimizer=trainer.optimizer))
if params.trainer.recovery_max_trials >= 0:
recovery_condition = RecoveryCondition(
global_step=trainer.global_step,
loss_upper_bound=params.trainer.loss_upper_bound,
recovery_begin_steps=params.trainer.recovery_begin_steps,
recovery_max_trials=params.trainer.recovery_max_trials,
)
recover_action = orbit.actions.ConditionalAction(
condition=recovery_condition,
action=RecoveryAction(checkpoint_manager),
)
train_actions.append(recover_action)
return train_actions
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for TFM actions."""
import os
from absl.testing import parameterized
import numpy as np
import orbit
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.core import actions
from official.modeling import optimization
class TestModel(tf.Module):
def __init__(self):
self.value = tf.Variable(0)
@tf.function(input_signature=[])
def __call__(self):
return self.value
class ActionsTest(tf.test.TestCase, parameterized.TestCase):
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],))
def test_ema_checkpointing(self, distribution):
with distribution.scope():
directory = self.create_tempdir()
model = TestModel()
optimizer = tf.keras.optimizers.SGD()
optimizer = optimization.ExponentialMovingAverage(
optimizer, trainable_weights_only=False)
# Creats average weights for the model variables. Average weights are
# initialized to zero.
optimizer.shadow_copy(model)
checkpoint = tf.train.Checkpoint(model=model)
# Changes model.value to 3, average value is still 0.
model.value.assign(3)
# Checks model.value is 3
self.assertEqual(model(), 3)
ema_action = actions.EMACheckpointing(directory, optimizer, checkpoint)
ema_action({})
self.assertNotEmpty(
tf.io.gfile.glob(os.path.join(directory, 'ema_checkpoints')))
checkpoint.read(tf.train.latest_checkpoint(
os.path.join(directory, 'ema_checkpoints')))
# Checks model.value is 0 after swapping.
self.assertEqual(model(), 0)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],))
def test_recovery_condition(self, distribution):
with distribution.scope():
global_step = orbit.utils.create_global_step()
recover_condition = actions.RecoveryCondition(
global_step, loss_upper_bound=0.5, recovery_max_trials=2)
outputs = {'training_loss': 0.6}
self.assertTrue(recover_condition(outputs))
self.assertTrue(recover_condition(outputs))
with self.assertRaises(RuntimeError):
recover_condition(outputs)
global_step = orbit.utils.create_global_step()
recover_condition = actions.RecoveryCondition(
global_step, loss_upper_bound=0.5, recovery_max_trials=2)
outputs = {'training_loss': tf.constant([np.nan], tf.float32)}
self.assertTrue(recover_condition(outputs))
self.assertTrue(recover_condition(outputs))
with self.assertRaises(RuntimeError):
recover_condition(outputs)
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines the base task abstraction."""
import abc
from typing import Optional
from absl import logging
import tensorflow as tf
from official.core import config_definitions
from official.modeling import optimization
from official.modeling import performance
OptimizationConfig = optimization.OptimizationConfig
RuntimeConfig = config_definitions.RuntimeConfig
class Task(tf.Module, metaclass=abc.ABCMeta):
"""A single-replica view of training procedure.
Tasks provide artifacts for training/validation procedures, including
loading/iterating over Datasets, training/validation steps, calculating the
loss and customized metrics with reduction.
"""
# Special keys in train/validate step returned logs.
loss = "loss"
def __init__(self,
params,
logging_dir: Optional[str] = None,
name: Optional[str] = None):
"""Task initialization.
Args:
params: the task configuration instance, which can be any of dataclass,
ConfigDict, namedtuple, etc.
logging_dir: a string pointing to where the model, summaries etc. will be
saved. You can also write additional stuff in this directory.
name: the task name.
"""
super().__init__(name=name)
self._task_config = params
self._logging_dir = logging_dir
@property
def task_config(self):
return self._task_config
@property
def logging_dir(self) -> str:
return self._logging_dir
@classmethod
def create_optimizer(cls, optimizer_config: OptimizationConfig,
runtime_config: Optional[RuntimeConfig] = None):
"""Creates an TF optimizer from configurations.
Args:
optimizer_config: the parameters of the Optimization settings.
runtime_config: the parameters of the runtime.
Returns:
A tf.optimizers.Optimizer object.
"""
opt_factory = optimization.OptimizerFactory(optimizer_config)
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
# Configuring optimizer when loss_scale is set in runtime config. This helps
# avoiding overflow/underflow for float16 computations.
if runtime_config:
optimizer = performance.configure_optimizer(
optimizer,
use_float16=runtime_config.mixed_precision_dtype == "float16",
loss_scale=runtime_config.loss_scale)
return optimizer
def initialize(self, model: tf.keras.Model):
"""[Optional] A callback function used as CheckpointManager's init_fn.
This function will be called when no checkpoint is found for the model.
If there is a checkpoint, the checkpoint will be loaded and this function
will not be called. You can use this callback function to load a pretrained
checkpoint, saved under a directory other than the model_dir.
Args:
model: The keras.Model built or used by this task.
"""
ckpt_dir_or_file = self.task_config.init_checkpoint
logging.info("Trying to load pretrained checkpoint from %s",
ckpt_dir_or_file)
if tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
if not ckpt_dir_or_file:
return
if hasattr(model, "checkpoint_items"):
checkpoint_items = model.checkpoint_items
else:
checkpoint_items = dict(model=model)
ckpt = tf.train.Checkpoint(**checkpoint_items)
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
logging.info("Finished loading pretrained checkpoint from %s",
ckpt_dir_or_file)
def build_model(self) -> tf.keras.Model:
"""[Optional] Creates model architecture.
Returns:
A model instance.
""" # pytype: disable=bad-return-type # typed-keras
@abc.abstractmethod
def build_inputs(self,
params,
input_context: Optional[tf.distribute.InputContext] = None):
"""Returns a dataset or a nested structure of dataset functions.
Dataset functions define per-host datasets with the per-replica batch size.
With distributed training, this method runs on remote hosts.
Args:
params: hyperparams to create input pipelines, which can be any of
dataclass, ConfigDict, namedtuple, etc.
input_context: optional distribution input pipeline context.
Returns:
A nested structure of per-replica input functions.
"""
def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
"""Standard interface to compute losses.
Args:
labels: optional label tensors.
model_outputs: a nested structure of output tensors.
aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.
Returns:
The total loss tensor.
"""
del model_outputs, labels
if aux_losses is None:
losses = [tf.constant(0.0, dtype=tf.float32)]
else:
losses = aux_losses
total_loss = tf.add_n(losses)
return total_loss
def build_metrics(self, training: bool = True):
"""Gets streaming metrics for training/validation."""
del training
return []
def process_metrics(self, metrics, labels, model_outputs, **kwargs):
"""Process and update metrics.
Called when using custom training loop API.
Args:
metrics: a nested structure of metrics objects. The return of function
self.build_metrics.
labels: a tensor or a nested structure of tensors.
model_outputs: a tensor or a nested structure of tensors. For example,
output of the keras model built by self.build_model.
**kwargs: other args.
"""
for metric in metrics:
metric.update_state(labels, model_outputs)
def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
"""Process and update compiled_metrics.
call when using compile/fit API.
Args:
compiled_metrics: the compiled metrics (model.compiled_metrics).
labels: a tensor or a nested structure of tensors.
model_outputs: a tensor or a nested structure of tensors. For example,
output of the keras model built by self.build_model.
"""
compiled_metrics.update_state(labels, model_outputs)
def train_step(self,
inputs,
model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer,
metrics=None):
"""Does forward and backward.
With distribution strategies, this method runs on devices.
Args:
inputs: a dictionary of input tensors.
model: the model, forward pass definition.
optimizer: the optimizer for this training step.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
if isinstance(inputs, tuple) and len(inputs) == 2:
features, labels = inputs
else:
features, labels = inputs, inputs
with tf.GradientTape() as tape:
outputs = model(features, training=True)
# Computes per-replica loss.
if model.compiled_loss:
loss = model.compiled_loss(
labels, outputs, regularization_losses=model.losses)
loss += self.build_losses(
labels=labels, model_outputs=outputs, aux_losses=None)
else:
loss = self.build_losses(
labels=labels, model_outputs=outputs, aux_losses=model.losses)
# Scales loss as the default gradients allreduce performs sum inside the
# optimizer.
scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
# For mixed precision, when a LossScaleOptimizer is used, the loss is
# scaled to avoid numeric underflow.
if isinstance(optimizer,
tf.keras.mixed_precision.LossScaleOptimizer):
scaled_loss = optimizer.get_scaled_loss(scaled_loss)
tvars = model.trainable_variables
grads = tape.gradient(scaled_loss, tvars)
if isinstance(optimizer,
tf.keras.mixed_precision.LossScaleOptimizer):
grads = optimizer.get_unscaled_gradients(grads)
optimizer.apply_gradients(list(zip(grads, tvars)))
logs = {self.loss: loss}
if metrics:
self.process_metrics(metrics, labels, outputs)
if model.compiled_metrics:
self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
logs.update({m.name: m.result() for m in metrics or []})
logs.update({m.name: m.result() for m in model.metrics})
return logs
def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
"""Validation step.
With distribution strategies, this method runs on devices.
Args:
inputs: a dictionary of input tensors.
model: the keras.Model.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
if isinstance(inputs, tuple) and len(inputs) == 2:
features, labels = inputs
else:
features, labels = inputs, inputs
outputs = self.inference_step(features, model)
loss = self.build_losses(
labels=labels, model_outputs=outputs, aux_losses=model.losses)
logs = {self.loss: loss}
if metrics:
self.process_metrics(metrics, labels, outputs)
if model.compiled_metrics:
self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
logs.update({m.name: m.result() for m in metrics or []})
logs.update({m.name: m.result() for m in model.metrics})
return logs
def inference_step(self, inputs, model: tf.keras.Model):
"""Performs the forward step.
With distribution strategies, this method runs on devices.
Args:
inputs: a dictionary of input tensors.
model: the keras.Model.
Returns:
Model outputs.
"""
return model(inputs, training=False)
def aggregate_logs(self, state, step_logs):
"""Optional aggregation over logs returned from a validation step.
Given step_logs from a validation step, this function aggregates the logs
after each eval_step() (see eval_reduce() function in
official/core/base_trainer.py). It runs on CPU and can be used to aggregate
metrics during validation, when there are too many metrics that cannot fit
into TPU memory. Note that this may increase latency due to data transfer
between TPU and CPU. Also, the step output from a validation step may be a
tuple with elements from replicas, and a concatenation of the elements is
needed in such case.
Args:
state: The current state of training, for example, it can be a sequence of
metrics.
step_logs: Logs from a validation step. Can be a dictionary.
"""
pass
def reduce_aggregated_logs(self,
aggregated_logs,
global_step: Optional[tf.Tensor] = None):
"""Optional reduce of aggregated logs over validation steps.
This function reduces aggregated logs at the end of validation, and can be
used to compute the final metrics. It runs on CPU and in each eval_end() in
base trainer (see eval_end() function in official/core/base_trainer.py).
Args:
aggregated_logs: Aggregated logs over multiple validation steps.
global_step: An optional variable of global step.
Returns:
A dictionary of reduced results.
"""
return {}
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Standard Trainer implementation.
The base trainer implements the Orbit `StandardTrainable` and
`StandardEvaluable` interfaces. Trainers inside this project should be
interchangable and independent on model architectures and tasks.
"""
import functools
from typing import Union, Optional
from absl import logging
import gin
import orbit
import tensorflow as tf
from official.core import base_task
from official.core import config_definitions
from official.modeling import optimization
ExperimentConfig = config_definitions.ExperimentConfig
TrainerConfig = config_definitions.TrainerConfig
class Recovery:
"""Built-in model blowup recovery module.
Checks the loss value by the given threshold. If applicable, recover the
model by reading the checkpoint on disk.
"""
def __init__(self,
loss_upper_bound: float,
checkpoint_manager: tf.train.CheckpointManager,
recovery_begin_steps: int = 0,
recovery_max_trials: int = 3):
self.recover_counter = 0
self.recovery_begin_steps = recovery_begin_steps
self.recovery_max_trials = recovery_max_trials
self.loss_upper_bound = loss_upper_bound
self.checkpoint_manager = checkpoint_manager
def should_recover(self, loss_value, global_step):
if tf.math.is_nan(loss_value):
return True
if (global_step >= self.recovery_begin_steps and
loss_value > self.loss_upper_bound):
return True
return False
def maybe_recover(self, loss_value, global_step):
"""Conditionally recovers the training by triggering checkpoint restoration.
Args:
loss_value: the loss value as a float.
global_step: the number of global training steps.
Raises:
RuntimeError: when recovery happens more than the max number of trials,
the job should crash.
"""
if not self.should_recover(loss_value, global_step):
return
self.recover_counter += 1
if self.recover_counter > self.recovery_max_trials:
raise RuntimeError(
"The loss value is NaN or out of range after training loop and "
f"this happens {self.recover_counter} times.")
# Loads the previous good checkpoint.
checkpoint_path = self.checkpoint_manager.restore_or_initialize()
logging.warning(
"Recovering the model from checkpoint: %s. The loss value becomes "
"%f at step %d.", checkpoint_path, loss_value, global_step)
class _AsyncTrainer(orbit.StandardTrainer, orbit.StandardEvaluator):
"""Trainer class for both sync and async Strategy."""
def init_async(self):
"""Initializes the Async Trainer base class."""
assert isinstance(self._strategy, tf.distribute.Strategy)
self._is_async = isinstance(
self._strategy, tf.distribute.experimental.ParameterServerStrategy)
self._coordinator = None
if self._is_async:
self._coordinator = (
tf.distribute.experimental.coordinator.ClusterCoordinator(
self._strategy))
def join(self):
"""Join all async steps. Only useful in aysnc training."""
if getattr(self, "_is_async", False):
self._coordinator.join()
def create_train_loop_fn(self):
"""Creates a eval loop from the given step function and options."""
train_loop_fn = super().create_train_loop_fn()
if getattr(self, "_is_async", False):
def _async_loop_fn(iterator, num_steps):
self._coordinator.schedule(train_loop_fn, args=(iterator, num_steps))
return _async_loop_fn
else:
return train_loop_fn
def create_eval_loop_fn(self, has_state: bool):
"""Creates a training loop from the given step function and options."""
eval_loop_fn = super().create_eval_loop_fn(has_state)
if getattr(self, "_is_async", False):
if has_state:
raise ValueError(
"Stateful eval loop is not supported in async training.")
def _async_loop_fn(iterator, num_steps, state=None, reduce_fn=None):
assert state is None
assert reduce_fn is None
self._coordinator.schedule(eval_loop_fn, args=(iterator, num_steps))
return _async_loop_fn
else:
return eval_loop_fn
def distribute_dataset(self, dataset_or_fn, *args, **kwargs):
"""A utility function to help create a `tf.distribute.DistributedDataset`.
Args:
dataset_or_fn: A instance of `tf.data.Dataset`, or a "dataset function"
returning a `tf.data.Dataset`. If it is a function, it may optionally
have an argument named `input_context` which will be passed a
`tf.distribute.InputContext` instance.
*args: Any positional arguments to pass through to `dataset_or_fn`.
**kwargs: Any keyword arguments to pass through to `dataset_or_fn`.
Returns:
A distributed Dataset.
"""
if getattr(self, "_is_async", False):
per_worker_dataset_fn = functools.partial(
orbit.utils.make_distributed_dataset, self._strategy, dataset_or_fn,
*args, **kwargs)
per_worker_dataset_fn = tf.function(per_worker_dataset_fn)
return self._coordinator.create_per_worker_dataset(per_worker_dataset_fn)
else:
return orbit.utils.make_distributed_dataset(self._strategy, dataset_or_fn,
*args, **kwargs)
def get_runtime_options(config: ExperimentConfig):
"""Get tf.distribute.RunOptions from config."""
xla_options = {}
if config.runtime.tpu_enable_xla_dynamic_padder is not None:
xla_options["enable_xla_dynamic_padder"] = (
config.runtime.tpu_enable_xla_dynamic_padder)
return tf.distribute.RunOptions(
experimental_xla_options=tf.tpu.XLAOptions(**xla_options))
@gin.configurable
class Trainer(_AsyncTrainer):
"""Implements the common trainer shared for TensorFlow models."""
# pylint: disable=super-init-not-called
def __init__(
self,
config: ExperimentConfig,
task: base_task.Task,
model: tf.keras.Model,
optimizer: tf.optimizers.Optimizer,
train: bool = True,
evaluate: bool = True,
train_dataset: Optional[Union[tf.data.Dataset,
tf.distribute.DistributedDataset]] = None,
validation_dataset: Optional[Union[
tf.data.Dataset, tf.distribute.DistributedDataset]] = None,
checkpoint_exporter=None):
"""Initialize common trainer for TensorFlow models.
Args:
config: An `ExperimentConfig` instance specifying experiment config.
task: A base_task.Task instance.
model: The model instance, e.g. a tf.keras.Model instance.
optimizer: tf.optimizers.Optimizer instance.
train: bool, whether or not this trainer will be used for training.
default to True.
evaluate: bool, whether or not this trainer will be used for evaluation.
default to True.
train_dataset: a dataset object created for training. With tf.distribute,
it needs to be a `DistributedDataset`.
validation_dataset: a dataset object created for evaluation. With
tf.distribute, it needs to be a `DistributedDataset`. The evaluator will
create a dataset iterator for each eval round, so the dataset does not
need to repeat.
checkpoint_exporter: an object that has the `maybe_export_checkpoint`
interface.
"""
# Gets the current distribution strategy. If not inside any strategy scope,
# it gets a single-replica no-op strategy.
self._strategy = tf.distribute.get_strategy()
self._validate_params(
config,
check_train_data=train_dataset is None,
check_validation_data=validation_dataset is None)
self._config = config
self._task = task
self._model = model
self._optimizer = optimizer
self._checkpoint_exporter = checkpoint_exporter
self._recovery = None
# Runtime options are only applied to train_step.
# We use default for eval_step.
self._runtime_options = get_runtime_options(config)
# Creates a shadow copy of the weights to store weights moving average.
if isinstance(self._optimizer, optimization.ExponentialMovingAverage
) and not self._optimizer.has_shadow_copy:
self._optimizer.shadow_copy(self._model)
# global_step increases by 1 after each training iteration.
# We should have global_step.numpy() == self.optimizer.iterations.numpy()
# when there is only 1 optimizer.
self._global_step = orbit.utils.create_global_step()
if hasattr(self.model, "checkpoint_items"):
checkpoint_items = self.model.checkpoint_items
else:
checkpoint_items = {}
self._checkpoint = tf.train.Checkpoint(
global_step=self.global_step,
model=self.model,
optimizer=self.optimizer,
**checkpoint_items)
self._train_loss = tf.keras.metrics.Mean("training_loss", dtype=tf.float32)
self._validation_loss = tf.keras.metrics.Mean(
"validation_loss", dtype=tf.float32)
model_metrics = model.metrics if hasattr(model, "metrics") else []
self._train_metrics = self.task.build_metrics(
training=True) + model_metrics
self._validation_metrics = self.task.build_metrics(
training=False) + model_metrics
self.init_async()
if train:
train_dataset = train_dataset or self.distribute_dataset(
self.task.build_inputs, self.config.task.train_data)
orbit.StandardTrainer.__init__(
self,
train_dataset,
options=orbit.StandardTrainerOptions(
use_tf_while_loop=config.trainer.train_tf_while_loop,
use_tf_function=config.trainer.train_tf_function,
use_tpu_summary_optimization=config.trainer.allow_tpu_summary))
if evaluate:
validation_dataset = validation_dataset or self.distribute_dataset(
self.task.build_inputs, self.config.task.validation_data)
orbit.StandardEvaluator.__init__(
self,
validation_dataset,
options=orbit.StandardEvaluatorOptions(
use_tf_function=config.trainer.eval_tf_function,
use_tf_while_loop=config.trainer.eval_tf_while_loop))
def _validate_params(self,
config,
check_train_data=True,
check_validation_data=True):
r"""Validates if the configuration object passed to the Trainer.
The experiment configuration should be structured as:
\trainer
\task
\train_data
\validation_data
Args:
config: a namedtuple, dataclass, ConfigDict, etc.
check_train_data: whether to check task.train_data field.
check_validation_data: whether to check task.validation_data field.
"""
if not hasattr(config, "trainer"):
raise AttributeError("The trainer requires the configuration contains an"
" attribute `trainer`.")
if not hasattr(config, "task"):
raise AttributeError("The trainer requires the configuration contains an"
" attribute `task`.")
if check_train_data and not hasattr(config.task, "train_data"):
raise AttributeError("The trainer requires the configuration contains an"
" attribute `task.train_data`.")
if check_validation_data and not hasattr(config.task, "validation_data"):
raise AttributeError("The trainer requires the configuration contains an"
" attribute `task.validation_data`.")
@property
def strategy(self):
return self._strategy
@property
def config(self):
return self._config
@property
def task(self):
return self._task
@property
def model(self):
return self._model
@property
def optimizer(self):
if hasattr(self, "_optimizer"):
return self._optimizer
else:
return None
@property
def global_step(self):
return self._global_step
@property
def train_loss(self):
"""Accesses the training loss metric object."""
return self._train_loss
@property
def validation_loss(self):
"""Accesses the validation loss metric object."""
return self._validation_loss
@property
def train_metrics(self):
"""Accesses all training metric objects."""
return self._train_metrics
@property
def validation_metrics(self):
"""Accesses all validation metric metric objects."""
return self._validation_metrics
def initialize(self):
"""A callback function.
This function will be called when no checkpoint found for the model.
If there is a checkpoint, the checkpoint will be loaded and this function
will not be called. Tasks may use this callback function to load a
pretrained checkpoint, saved under a directory other than the model_dir.
"""
self.task.initialize(self.model)
@property
def checkpoint(self):
"""Accesses the training checkpoint."""
return self._checkpoint
# TODO(yejiayu): Remove this once all deps are fixed.
def add_recovery(self, params: TrainerConfig,
checkpoint_manager: tf.train.CheckpointManager):
if params.recovery_max_trials >= 0:
self._recovery = Recovery(
loss_upper_bound=params.loss_upper_bound,
recovery_begin_steps=params.recovery_begin_steps,
recovery_max_trials=params.recovery_max_trials,
checkpoint_manager=checkpoint_manager)
def train_loop_end(self):
"""See base class."""
self.join()
logs = {}
for metric in self.train_metrics + [self.train_loss]:
logs[metric.name] = metric.result()
metric.reset_states()
if callable(self.optimizer.learning_rate):
# Maybe a self-implemented optimizer does not have `optimizer.iterations`.
# So just to be safe here.
if hasattr(self.optimizer, "iterations"):
logs["learning_rate"] = self.optimizer.learning_rate(
self.optimizer.iterations)
else:
logs["learning_rate"] = self.optimizer.learning_rate(self.global_step)
else:
logs["learning_rate"] = self.optimizer.learning_rate
return logs
def train_step(self, iterator):
"""See base class."""
def step_fn(inputs):
if self.config.runtime.enable_xla and (self.config.runtime.num_gpus > 0):
task_train_step = tf.function(self.task.train_step, jit_compile=True)
else:
task_train_step = self.task.train_step
logs = task_train_step(
inputs,
model=self.model,
optimizer=self.optimizer,
metrics=self.train_metrics)
self._train_loss.update_state(logs[self.task.loss])
self.global_step.assign_add(1)
self.strategy.run(
step_fn, args=(next(iterator),), options=self._runtime_options)
def eval_begin(self):
"""Sets up metrics."""
for metric in self.validation_metrics + [self.validation_loss]:
metric.reset_states()
# Swaps weights to test on weights moving average.
if self.optimizer and isinstance(self.optimizer,
optimization.ExponentialMovingAverage):
self.optimizer.swap_weights()
def eval_step(self, iterator):
"""See base class."""
def step_fn(inputs):
logs = self.task.validation_step(
inputs, model=self.model, metrics=self.validation_metrics)
if self.task.loss in logs:
self._validation_loss.update_state(logs[self.task.loss])
return logs
distributed_outputs = self.strategy.run(step_fn, args=(next(iterator),))
return tf.nest.map_structure(self.strategy.experimental_local_results,
distributed_outputs)
def eval_end(self, aggregated_logs=None):
"""Processes evaluation results."""
self.join()
logs = {}
for metric in self.validation_metrics:
logs[metric.name] = metric.result()
if self.validation_loss.count.numpy() != 0:
logs[self.validation_loss.name] = self.validation_loss.result()
else:
# `self.validation_loss` metric was not updated, because the validation
# loss was not returned from the task's `validation_step` method.
logging.info("The task did not report validation loss.")
if aggregated_logs:
metrics = self.task.reduce_aggregated_logs(
aggregated_logs, global_step=self.global_step)
logs.update(metrics)
if self._checkpoint_exporter:
self._checkpoint_exporter.maybe_export_checkpoint(
self.checkpoint, logs, self.global_step.numpy())
metric_name = self.config.trainer.best_checkpoint_eval_metric
logs["best_" +
metric_name] = self._checkpoint_exporter.best_ckpt_logs[metric_name]
# Swaps back weights after testing when EMA is used.
# This happens after best checkpoint export so that average weights used for
# eval are exported instead of regular weights.
if self.optimizer and isinstance(self.optimizer,
optimization.ExponentialMovingAverage):
self.optimizer.swap_weights()
return logs
def eval_reduce(self, state=None, step_outputs=None):
return self.task.aggregate_logs(state, step_outputs)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tensorflow_models.core.trainers.trainer."""
# pylint: disable=g-direct-tensorflow-import
import gc
import multiprocessing
import os
import sys
from absl.testing import parameterized
import orbit
import portpicker
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.core import base_trainer as trainer_lib
from official.core import config_definitions as cfg
from official.core import train_lib
from official.utils.testing import mock_task
TPU_TEST = 'test_tpu' in sys.argv[0]
GPU_TEST = 'test_gpu' in sys.argv[0]
def all_strategy_combinations():
return combinations.combine(
distribution=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],)
def create_in_process_cluster(num_workers, num_ps):
"""Creates and starts local servers and returns the cluster_resolver."""
worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
cluster_dict = {}
cluster_dict['worker'] = ['localhost:%s' % port for port in worker_ports]
if num_ps > 0:
cluster_dict['ps'] = ['localhost:%s' % port for port in ps_ports]
cluster_spec = tf.train.ClusterSpec(cluster_dict)
# Workers need some inter_ops threads to work properly.
worker_config = tf.compat.v1.ConfigProto()
if multiprocessing.cpu_count() < num_workers + 1:
worker_config.inter_op_parallelism_threads = num_workers + 1
for i in range(num_workers):
tf.distribute.Server(
cluster_spec,
job_name='worker',
task_index=i,
config=worker_config,
protocol='grpc')
for i in range(num_ps):
tf.distribute.Server(
cluster_spec, job_name='ps', task_index=i, protocol='grpc')
cluster_resolver = tf.distribute.cluster_resolver.SimpleClusterResolver(
cluster_spec, rpc_layer='grpc')
return cluster_resolver
def dataset_fn(input_context=None):
del input_context
def dummy_data(_):
return tf.zeros((1, 1), dtype=tf.float32)
dataset = tf.data.Dataset.range(1)
dataset = dataset.repeat()
dataset = dataset.map(
dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
return dataset
class MockAsyncTrainer(trainer_lib._AsyncTrainer):
"""Mock AsyncTrainer to test the _AsyncTrainer class."""
def __init__(self):
self._strategy = tf.distribute.get_strategy()
self.init_async()
self.global_step = tf.Variable(
0,
dtype=tf.int64,
name='global_step',
trainable=False,
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
self.eval_global_step = tf.Variable(
0,
dtype=tf.int64,
name='eval_global_step',
trainable=False,
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
train_dataset = self.distribute_dataset(dataset_fn)
orbit.StandardTrainer.__init__(
self, train_dataset, options=orbit.StandardTrainerOptions())
validation_dataset = self.distribute_dataset(dataset_fn)
orbit.StandardEvaluator.__init__(
self,
validation_dataset,
options=orbit.StandardEvaluatorOptions(use_tf_while_loop=True))
def train_loop_begin(self):
self.global_step.assign(0)
def train_step(self, iterator):
def replica_step(_):
self.global_step.assign_add(1)
self._strategy.run(replica_step, args=(next(iterator),))
def train_loop_end(self):
self.join()
return self.global_step.numpy()
def eval_begin(self):
self.eval_global_step.assign(0)
def eval_step(self, iterator):
def replica_step(_):
self.eval_global_step.assign_add(1)
self._strategy.run(replica_step, args=(next(iterator),))
def eval_end(self):
self.join()
return self.eval_global_step.numpy()
class RecoveryTest(tf.test.TestCase):
def test_recovery_module(self):
ckpt = tf.train.Checkpoint(v=tf.Variable(1, dtype=tf.int32))
model_dir = self.get_temp_dir()
manager = tf.train.CheckpointManager(ckpt, model_dir, max_to_keep=1)
recovery_module = trainer_lib.Recovery(
loss_upper_bound=1.0,
checkpoint_manager=manager,
recovery_begin_steps=1,
recovery_max_trials=1)
self.assertFalse(recovery_module.should_recover(1.1, 0))
self.assertFalse(recovery_module.should_recover(0.1, 1))
self.assertTrue(recovery_module.should_recover(1.1, 2))
# First triggers the recovery once.
recovery_module.maybe_recover(1.1, 10)
# Second time, it raises.
with self.assertRaisesRegex(
RuntimeError, 'The loss value is NaN .*'):
recovery_module.maybe_recover(1.1, 10)
class TrainerTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
self._config = cfg.ExperimentConfig(
trainer=cfg.TrainerConfig(
optimizer_config=cfg.OptimizationConfig({
'optimizer': {
'type': 'sgd'
},
'learning_rate': {
'type': 'constant'
}
})))
def tearDown(self):
gc.collect()
# This will only contain uncollectable garbage, i.e. reference cycles
# involving objects with __del__ defined.
self.assertEmpty(gc.garbage)
super().tearDown()
def create_test_trainer(self, config, model_dir=None, task=None):
task = task or mock_task.MockTask(config.task, logging_dir=model_dir)
ckpt_exporter = train_lib.maybe_create_best_ckpt_exporter(config, model_dir)
trainer = trainer_lib.Trainer(
config,
task,
model=task.build_model(),
optimizer=task.create_optimizer(config.trainer.optimizer_config,
config.runtime),
checkpoint_exporter=ckpt_exporter)
return trainer
@combinations.generate(all_strategy_combinations())
def test_trainer_train(self, distribution):
with distribution.scope():
trainer = self.create_test_trainer(self._config)
logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertIn('training_loss', logs)
self.assertIn('learning_rate', logs)
@combinations.generate(all_strategy_combinations())
def test_trainer_passing_datasets(self, distribution):
with distribution.scope():
task = mock_task.MockTask(self._config)
train_dataset = orbit.utils.make_distributed_dataset(
distribution, task.build_inputs, self._config.task.train_data)
validation_dataset = orbit.utils.make_distributed_dataset(
distribution, task.build_inputs, self._config.task.validation_data)
self._config.task.train_data = None
self._config.task.validation_data = None
trainer = trainer_lib.Trainer(
self._config,
task,
model=task.build_model(),
optimizer=task.create_optimizer(self._config.trainer.optimizer_config,
self._config.runtime),
train_dataset=train_dataset,
validation_dataset=validation_dataset)
logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertIn('training_loss', logs)
self.assertIn('learning_rate', logs)
logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertIn('validation_loss', logs)
def test_base_async_trainer(self):
if TPU_TEST or GPU_TEST:
self.skipTest('Aysnc training is not available on GPU/GPU.')
num_workers = 3
num_ps = 2
cluster_resolver = create_in_process_cluster(num_workers, num_ps)
distribution = tf.distribute.experimental.ParameterServerStrategy(
cluster_resolver)
with distribution.scope():
trainer = MockAsyncTrainer()
trainer.init_async()
self.assertIsInstance(
trainer._coordinator,
tf.distribute.experimental.coordinator.ClusterCoordinator)
self.assertEqual(trainer.train(tf.constant(10)), 10)
self.assertEqual(trainer.evaluate(tf.constant(11)), 11)
def test_async_trainer_train(self):
if TPU_TEST or GPU_TEST:
self.skipTest('Aysnc training is not available on GPU/TPU.')
num_workers = 3
num_ps = 2
cluster_resolver = create_in_process_cluster(num_workers, num_ps)
distribution = tf.distribute.experimental.ParameterServerStrategy(
cluster_resolver)
with distribution.scope():
config = cfg.ExperimentConfig(**self._config.as_dict())
config.trainer.eval_tf_while_loop = True
trainer = self.create_test_trainer(config)
logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertIn('training_loss', logs)
self.assertIn('learning_rate', logs)
def test_async_trainer_validate(self):
if TPU_TEST or GPU_TEST:
self.skipTest('Aysnc training is not available on GPU/GPU.')
num_workers = 3
num_ps = 2
cluster_resolver = create_in_process_cluster(num_workers, num_ps)
distribution = tf.distribute.experimental.ParameterServerStrategy(
cluster_resolver)
with distribution.scope():
config = cfg.ExperimentConfig(**self._config.as_dict())
config.trainer.eval_tf_while_loop = True
trainer = self.create_test_trainer(config)
logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertIn('acc', logs)
self.assertIn('validation_loss', logs)
@combinations.generate(all_strategy_combinations())
def test_trainer_validate(self, distribution):
with distribution.scope():
trainer = self.create_test_trainer(self._config)
logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertEqual(logs['counter'], 5. * distribution.num_replicas_in_sync)
self.assertIn('validation_loss', logs)
@combinations.generate(all_strategy_combinations())
def test_trainer_validate_without_loss(self, distribution):
class MockTaskWithoutValidationLoss(mock_task.MockTask):
def validation_step(self, inputs, model, metrics=None):
# Disable validation loss.
logs = super().validation_step(inputs, model)
del logs[self.loss]
return logs
with distribution.scope():
task = MockTaskWithoutValidationLoss()
trainer = self.create_test_trainer(self._config, task=task)
logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertEqual(logs['counter'], 5. * distribution.num_replicas_in_sync)
self.assertNotIn('validation_loss', logs)
@combinations.generate(
combinations.combine(
mixed_precision_dtype=['float32', 'bfloat16', 'float16'],
loss_scale=[None, 'dynamic', 128, 256],
))
def test_configure_optimizer(self, mixed_precision_dtype, loss_scale):
config = cfg.ExperimentConfig(
runtime=cfg.RuntimeConfig(
mixed_precision_dtype=mixed_precision_dtype, loss_scale=loss_scale),
trainer=cfg.TrainerConfig(
optimizer_config=cfg.OptimizationConfig({
'optimizer': {
'type': 'sgd'
},
'learning_rate': {
'type': 'constant'
},
})))
trainer = self.create_test_trainer(config)
if mixed_precision_dtype == 'float16':
self.assertIsInstance(trainer.optimizer,
tf.keras.mixed_precision.LossScaleOptimizer)
if loss_scale in (None, 'dynamic'):
self.assertTrue(trainer.optimizer.dynamic)
else:
self.assertFalse(trainer.optimizer.dynamic)
self.assertEqual(trainer.optimizer.initial_scale, loss_scale)
else:
self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)
metrics = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertIn('training_loss', metrics)
def test_export_best_ckpt(self):
config = cfg.ExperimentConfig(
trainer=cfg.TrainerConfig(
best_checkpoint_export_subdir='best_ckpt',
best_checkpoint_eval_metric='acc',
optimizer_config=cfg.OptimizationConfig({
'optimizer': {
'type': 'sgd'
},
'learning_rate': {
'type': 'constant'
}
})))
model_dir = self.get_temp_dir()
trainer = self.create_test_trainer(config, model_dir=model_dir)
trainer.train(tf.convert_to_tensor(1, dtype=tf.int32))
trainer.evaluate(tf.convert_to_tensor(1, dtype=tf.int32))
self.assertTrue(
tf.io.gfile.exists(os.path.join(model_dir, 'best_ckpt', 'info.json')))
def test_model_with_compiled_loss(self):
task = mock_task.MockTask()
model = task.build_model()
model.compile(loss=tf.keras.losses.CategoricalCrossentropy())
trainer = trainer_lib.Trainer(
self._config,
task,
model=model,
optimizer=task.create_optimizer(self._config.trainer.optimizer_config))
logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertIn('training_loss', logs)
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common configuration settings."""
import dataclasses
from typing import Optional, Sequence, Union
from official.modeling.hyperparams import base_config
from official.modeling.optimization.configs import optimization_config
OptimizationConfig = optimization_config.OptimizationConfig
@dataclasses.dataclass
class DataConfig(base_config.Config):
"""The base configuration for building datasets.
Attributes:
input_path: The path to the input. It can be either (1) a str indicating a
file path/pattern, or (2) a str indicating multiple file paths/patterns
separated by comma (e.g "a, b, c" or no spaces "a,b,c"), or (3) a list of
str, each of which is a file path/pattern or multiple file paths/patterns
separated by comma, or (4) a dictionary of the previous three approaches
for more advanced data mixing using named access. It should not be
specified when the following `tfds_name` is specified.
tfds_name: The name of the tensorflow dataset (TFDS). It should not be
specified when the above `input_path` is specified.
tfds_split: A str indicating which split of the data to load from TFDS. It
is required when above `tfds_name` is specified.
global_batch_size: The global batch size across all replicas.
is_training: Whether this data is used for training or not. This flag is
useful for consumers of this object to determine whether the data should
be repeated or shuffled.
drop_remainder: Whether the last batch should be dropped in the case it has
fewer than `global_batch_size` elements.
shuffle_buffer_size: The buffer size used for shuffling training data.
cache: Whether to cache dataset examples. If `True`, we will cache the
dataset after applying the decode_fn and parse_fn. It can be used to avoid
re-reading from disk, re-decoding and re-parsing the example on the second
epoch, but it requires significant memory overhead.
cycle_length: The number of files that will be processed concurrently when
interleaving files.
block_length: The number of consecutive elements to produce from each input
element before cycling to another input element when interleaving files.
deterministic: A boolean controlling whether determinism should be enforced.
sharding: Whether sharding is used in the input pipeline.
enable_tf_data_service: A boolean indicating whether to enable tf.data
service for the input pipeline.
tf_data_service_address: The URI of a tf.data service to offload
preprocessing onto during training. The URI should be in the format
"protocol://address", e.g. "grpc://tf-data-service:5050". It can be
overridden by `FLAGS.tf_data_service` flag in the binary.
tf_data_service_job_name: The name of the tf.data service job. This argument
makes it possible for multiple datasets to share the same job. The default
behavior is that the dataset creates anonymous, exclusively owned jobs.
tfds_data_dir: A str specifying the directory to read/write TFDS data.
tfds_as_supervised: A bool. When loading dataset from TFDS, if True, the
returned tf.data.Dataset will have a 2-tuple structure (input, label)
according to builder.info.supervised_keys; if False, the default, the
returned tf.data.Dataset will have a dictionary with all the features.
tfds_skip_decoding_feature: A str to indicate which features are skipped for
decoding when loading dataset from TFDS. Use comma to separate multiple
features. The main use case is to skip the image/video decoding for better
performance.
seed: An optional seed to use for deterministic shuffling/preprocessing.
"""
input_path: Union[Sequence[str], str, base_config.Config] = ""
tfds_name: str = ""
tfds_split: str = ""
global_batch_size: int = 0
is_training: bool = None
drop_remainder: bool = True
shuffle_buffer_size: int = 100
cache: bool = False
cycle_length: Optional[int] = None
block_length: int = 1
deterministic: Optional[bool] = None
sharding: bool = True
enable_tf_data_service: bool = False
tf_data_service_address: Optional[str] = None
tf_data_service_job_name: Optional[str] = None
tfds_data_dir: str = ""
tfds_as_supervised: bool = False
tfds_skip_decoding_feature: str = ""
seed: Optional[int] = None
@dataclasses.dataclass
class RuntimeConfig(base_config.Config):
"""High-level configurations for Runtime.
These include parameters that are not directly related to the experiment,
e.g. directories, accelerator type, etc.
Attributes:
distribution_strategy: e.g. 'mirrored', 'tpu', etc.
enable_xla: Whether or not to enable XLA.
per_gpu_thread_count: thread count per GPU.
gpu_thread_mode: Whether and how the GPU device uses its own threadpool.
dataset_num_private_threads: Number of threads for a private threadpool
created for all datasets computation.
tpu: The address of the TPU to use, if any.
num_gpus: The number of GPUs to use, if any.
worker_hosts: comma-separated list of worker ip:port pairs for running
multi-worker models with DistributionStrategy.
task_index: If multi-worker training, the task index of this worker.
all_reduce_alg: Defines the algorithm for performing all-reduce.
num_packs: Sets `num_packs` in the cross device ops used in
MirroredStrategy. For details, see tf.distribute.NcclAllReduce.
mixed_precision_dtype: dtype of mixed precision policy. It can be 'float32',
'float16', or 'bfloat16'.
loss_scale: The type of loss scale, or 'float' value. This is used when
setting the mixed precision policy.
run_eagerly: Whether or not to run the experiment eagerly.
batchnorm_spatial_persistent: Whether or not to enable the spatial
persistent mode for CuDNN batch norm kernel for improved GPU performance.
"""
distribution_strategy: str = "mirrored"
enable_xla: bool = False
gpu_thread_mode: Optional[str] = None
dataset_num_private_threads: Optional[int] = None
per_gpu_thread_count: int = 0
tpu: Optional[str] = None
num_gpus: int = 0
worker_hosts: Optional[str] = None
task_index: int = -1
all_reduce_alg: Optional[str] = None
num_packs: int = 1
mixed_precision_dtype: Optional[str] = None
loss_scale: Optional[Union[str, float]] = None
run_eagerly: bool = False
batchnorm_spatial_persistent: bool = False
# XLA runtime params.
# XLA params are only applied to the train_step.
# These augments can improve training speed. They can also improve eval, but
# may reduce usability and users would need to make changes to code.
# Whether to enable XLA dynamic padder
# infrastructure to handle dynamic shapes inputs inside XLA. True by
# default. Disabling this may cause correctness issues with dynamic shapes
# inputs, as XLA will just assume the inputs are with padded shapes. However
# users can optionally set it to False to improve device time if masking is
# already handled in the user side.
# If None, will respect XLA default.
tpu_enable_xla_dynamic_padder: Optional[bool] = None
# Global model parallelism configurations.
num_cores_per_replica: int = 1
default_shard_dim: int = -1
def model_parallelism(self):
return dict(
num_cores_per_replica=self.num_cores_per_replica,
default_shard_dim=self.default_shard_dim)
@dataclasses.dataclass
class TrainerConfig(base_config.Config):
"""Configuration for trainer.
Attributes:
optimizer_config: optimizer config, it includes optimizer, learning rate,
and warmup schedule configs.
train_tf_while_loop: whether or not to use tf while loop.
train_tf_function: whether or not to use tf_function for training loop.
eval_tf_function: whether or not to use tf_function for eval.
allow_tpu_summary: Whether to allow summary happen inside the XLA program
runs on TPU through automatic outside compilation.
steps_per_loop: number of steps per loop to report training metrics. This
can also be used to reduce host worker communication in a TPU setup.
summary_interval: number of steps between each summary.
checkpoint_interval: number of steps between checkpoints.
max_to_keep: max checkpoints to keep.
continuous_eval_timeout: maximum number of seconds to wait between
checkpoints, if set to None, continuous eval will wait indefinitely. This
is only used continuous_train_and_eval and continuous_eval modes. Default
value is 1 hrs.
train_steps: number of train steps.
validation_steps: number of eval steps. If `None`, the entire eval dataset
is used.
validation_interval: number of training steps to run between evaluations.
best_checkpoint_export_subdir: if set, the trainer will keep track of the
best evaluation metric, and export the corresponding best checkpoint under
`model_dir/best_checkpoint_export_subdir`. Note that this only works if
mode contains eval (such as `train_and_eval`, `continuous_eval`, and
`continuous_train_and_eval`).
best_checkpoint_eval_metric: for exporting the best checkpoint, which
evaluation metric the trainer should monitor. This can be any evaluation
metric appears on tensorboard.
best_checkpoint_metric_comp: for exporting the best checkpoint, how the
trainer should compare the evaluation metrics. This can be either `higher`
(higher the better) or `lower` (lower the better).
validation_summary_subdir: A 'str', sub directory for saving eval summary.
"""
optimizer_config: OptimizationConfig = OptimizationConfig()
# Orbit settings.
train_tf_while_loop: bool = True
train_tf_function: bool = True
eval_tf_function: bool = True
eval_tf_while_loop: bool = False
allow_tpu_summary: bool = False
# Trainer intervals.
steps_per_loop: int = 1000
summary_interval: int = 1000
checkpoint_interval: int = 1000
# Checkpoint manager.
max_to_keep: int = 5
continuous_eval_timeout: int = 60 * 60
# Train/Eval routines.
train_steps: int = 0
# Sets validation steps to be -1 to evaluate the entire dataset.
validation_steps: int = -1
validation_interval: int = 1000
# Best checkpoint export.
best_checkpoint_export_subdir: str = ""
best_checkpoint_eval_metric: str = ""
best_checkpoint_metric_comp: str = "higher"
# Blowup recovery.
loss_upper_bound: float = 1e6
recovery_begin_steps: int = 0 # Enforcing the loss bound after these steps.
# When max trials < 0, no recovery module; max trials = 0, we will check
# the condition and fail the job if the condition happens; max trials > 0,
# we will retore the model states.
recovery_max_trials: int = 0
validation_summary_subdir: str = "validation"
@dataclasses.dataclass
class TaskConfig(base_config.Config):
init_checkpoint: str = ""
model: Optional[base_config.Config] = None
train_data: DataConfig = DataConfig()
validation_data: DataConfig = DataConfig()
name: Optional[str] = None
@dataclasses.dataclass
class ExperimentConfig(base_config.Config):
"""Top-level configuration."""
task: TaskConfig = TaskConfig()
trainer: TrainerConfig = TrainerConfig()
runtime: RuntimeConfig = RuntimeConfig()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Experiment factory methods."""
from official.core import config_definitions as cfg
from official.core import registry
_REGISTERED_CONFIGS = {}
def register_config_factory(name):
"""Register ExperimentConfig factory method."""
return registry.register(_REGISTERED_CONFIGS, name)
def get_exp_config(exp_name: str) -> cfg.ExperimentConfig:
"""Looks up the `ExperimentConfig` according to the `exp_name`."""
exp_creater = registry.lookup(_REGISTERED_CONFIGS, exp_name)
return exp_creater()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base class for model export."""
import abc
import functools
from typing import Any, Callable, Dict, Mapping, List, Optional, Text, Union
import tensorflow as tf
from tensorflow.python.saved_model.model_utils import export_utils
class ExportModule(tf.Module, metaclass=abc.ABCMeta):
"""Base Export Module."""
def __init__(self,
params,
model: Union[tf.Module, tf.keras.Model],
inference_step: Optional[Callable[..., Any]] = None,
*,
preprocessor: Optional[Callable[..., Any]] = None,
postprocessor: Optional[Callable[..., Any]] = None):
"""Instantiates an ExportModel.
Examples:
`inference_step` must be a function that has `model` as an kwarg or the
second positional argument.
```
def _inference_step(inputs, model=None):
return model(inputs, training=False)
module = ExportModule(params, model, inference_step=_inference_step)
```
`preprocessor` and `postprocessor` could be either functions or `tf.Module`.
The usages of preprocessor and postprocessor are managed by the
implementation of `serve()` method.
Args:
params: A dataclass for parameters to the module.
model: A model instance which contains weights and forward computation.
inference_step: An optional callable to forward-pass the model. If not
specified, it creates a parital function with `model` as an required
kwarg.
preprocessor: An optional callable to preprocess the inputs.
postprocessor: An optional callable to postprocess the model outputs.
"""
super().__init__(name=None)
self.model = model
self.params = params
if inference_step is not None:
self.inference_step = functools.partial(inference_step, model=self.model)
else:
self.inference_step = functools.partial(
self.model.__call__, training=False)
self.preprocessor = preprocessor
self.postprocessor = postprocessor
@abc.abstractmethod
def serve(self) -> Mapping[Text, tf.Tensor]:
"""The bare inference function which should run on all devices.
Expecting tensors are passed in through keyword arguments. Returns a
dictionary of tensors, when the keys will be used inside the SignatureDef.
"""
@abc.abstractmethod
def get_inference_signatures(
self, function_keys: Dict[Text, Text]) -> Mapping[Text, Any]:
"""Get defined function signatures."""
def export(export_module: ExportModule,
function_keys: Union[List[Text], Dict[Text, Text]],
export_savedmodel_dir: Text,
checkpoint_path: Optional[Text] = None,
timestamped: bool = True,
save_options: Optional[tf.saved_model.SaveOptions] = None) -> Text:
"""Exports to SavedModel format.
Args:
export_module: a ExportModule with the keras Model and serving tf.functions.
function_keys: a list of string keys to retrieve pre-defined serving
signatures. The signaute keys will be set with defaults. If a dictionary
is provided, the values will be used as signature keys.
export_savedmodel_dir: Output saved model directory.
checkpoint_path: Object-based checkpoint path or directory.
timestamped: Whether to export the savedmodel to a timestamped directory.
save_options: `SaveOptions` for `tf.saved_model.save`.
Returns:
The savedmodel directory path.
"""
ckpt_dir_or_file = checkpoint_path
if ckpt_dir_or_file is not None and tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
if ckpt_dir_or_file:
checkpoint = tf.train.Checkpoint(model=export_module.model)
checkpoint.read(
ckpt_dir_or_file).assert_existing_objects_matched().expect_partial()
if isinstance(function_keys, list):
if len(function_keys) == 1:
function_keys = {
function_keys[0]: tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
}
else:
raise ValueError(
"If the function_keys is a list, it must contain a single element. %s"
% function_keys)
signatures = export_module.get_inference_signatures(function_keys)
if timestamped:
export_dir = export_utils.get_timestamped_export_dir(
export_savedmodel_dir).decode("utf-8")
else:
export_dir = export_savedmodel_dir
tf.saved_model.save(
export_module, export_dir, signatures=signatures, options=save_options)
return export_dir
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.core.export_base."""
import os
from typing import Any, Dict, Mapping, Text
import tensorflow as tf
from official.core import export_base
class TestModule(export_base.ExportModule):
@tf.function
def serve(self, inputs: tf.Tensor) -> Mapping[Text, tf.Tensor]:
x = inputs if self.preprocessor is None else self.preprocessor(
inputs=inputs)
x = self.inference_step(x)
x = self.postprocessor(x) if self.postprocessor else x
return {'outputs': x}
def get_inference_signatures(
self, function_keys: Dict[Text, Text]) -> Mapping[Text, Any]:
input_signature = tf.TensorSpec(shape=[None, None], dtype=tf.float32)
return {'foo': self.serve.get_concrete_function(input_signature)}
class ExportBaseTest(tf.test.TestCase):
def test_export_module(self):
tmp_dir = self.get_temp_dir()
model = tf.keras.layers.Dense(2)
inputs = tf.ones([2, 4], tf.float32)
expected_output = model(inputs, training=False)
module = TestModule(params=None, model=model)
ckpt_path = tf.train.Checkpoint(model=model).save(
os.path.join(tmp_dir, 'ckpt'))
export_dir = export_base.export(
module, ['foo'],
export_savedmodel_dir=tmp_dir,
checkpoint_path=ckpt_path,
timestamped=True)
self.assertTrue(os.path.exists(os.path.join(export_dir, 'saved_model.pb')))
self.assertTrue(
os.path.exists(
os.path.join(export_dir, 'variables', 'variables.index')))
self.assertTrue(
os.path.exists(
os.path.join(export_dir, 'variables',
'variables.data-00000-of-00001')))
imported = tf.saved_model.load(export_dir)
output = imported.signatures['foo'](inputs)
self.assertAllClose(output['outputs'].numpy(), expected_output.numpy())
def test_custom_inference_step(self):
tmp_dir = self.get_temp_dir()
model = tf.keras.layers.Dense(2)
inputs = tf.ones([2, 4], tf.float32)
def _inference_step(inputs, model):
return tf.nn.softmax(model(inputs, training=False))
module = TestModule(
params=None, model=model, inference_step=_inference_step)
expected_output = _inference_step(inputs, model)
ckpt_path = tf.train.Checkpoint(model=model).save(
os.path.join(tmp_dir, 'ckpt'))
export_dir = export_base.export(
module, ['foo'],
export_savedmodel_dir=tmp_dir,
checkpoint_path=ckpt_path,
timestamped=False)
imported = tf.saved_model.load(export_dir)
output = imported.signatures['foo'](inputs)
self.assertAllClose(output['outputs'].numpy(), expected_output.numpy())
def test_processors(self):
model = tf.Module()
inputs = tf.zeros((), tf.float32)
def _inference_step(inputs, model):
del model
return inputs + 1.0
def _preprocessor(inputs):
print(inputs)
return inputs + 0.1
module = TestModule(
params=None,
model=model,
inference_step=_inference_step,
preprocessor=_preprocessor)
output = module.serve(inputs)
self.assertAllClose(output['outputs'].numpy(), 1.1)
class _PostProcessor(tf.Module):
def __call__(self, inputs):
return inputs + 0.01
module = TestModule(
params=None,
model=model,
inference_step=_inference_step,
preprocessor=_preprocessor,
postprocessor=_PostProcessor())
output = module.serve(inputs)
self.assertAllClose(output['outputs'].numpy(), 1.11)
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A common dataset reader."""
import random
from typing import Any, Callable, Dict, List, Optional, Sequence, Text, Union
from absl import logging
import tensorflow as tf
import tensorflow_datasets as tfds
from official.core import config_definitions as cfg
def _get_random_integer():
return random.randint(0, (1 << 31) - 1)
def _maybe_map_fn(dataset: tf.data.Dataset,
fn: Optional[Callable[..., Any]] = None) -> tf.data.Dataset:
"""Calls dataset.map if a valid function is passed in."""
return dataset if fn is None else dataset.map(
fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
def match_files(input_path: Union[Sequence[str], str]) -> List[str]:
"""Matches files from an input_path."""
matched_files = []
# Read dataset from files.
usage = ('`input_path` should be either (1) a str indicating a file '
'path/pattern, or (2) a str indicating multiple file '
'paths/patterns separated by comma (e.g "a, b, c" or no spaces '
'"a,b,c", or (3) a list of str, each of which is a file '
'path/pattern or multiple file paths/patterns separated by '
'comma, but got: %s')
if isinstance(input_path, str):
input_path_list = [input_path]
elif isinstance(input_path, (list, tuple)):
if any(not isinstance(x, str) for x in input_path):
raise ValueError(usage % input_path)
input_path_list = input_path
else:
raise ValueError(usage % input_path)
for input_path in input_path_list:
input_patterns = input_path.strip().split(',')
for input_pattern in input_patterns:
input_pattern = input_pattern.strip()
if not input_pattern:
continue
if '*' in input_pattern or '?' in input_pattern:
tmp_matched_files = tf.io.gfile.glob(input_pattern)
if not tmp_matched_files:
raise ValueError('%s does not match any files.' % input_pattern)
matched_files.extend(tmp_matched_files)
else:
matched_files.append(input_pattern)
if not matched_files:
raise ValueError('%s does not match any files.' % input_path)
return matched_files
def _read_files_then_shard(matched_files: List[str],
dataset_fn,
input_context: Optional[
tf.distribute.InputContext] = None,
sharding: bool = False,
repeat: bool = False) -> tf.data.Dataset:
"""Sends all data files to every worker and then shard by data."""
dataset = dataset_fn(matched_files)
# When `input_file` is a path to a single file or the number of files is
# less than the number of input pipelines, disable auto sharding
# so that same input file is sent to all workers.
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = (
tf.data.experimental.AutoShardPolicy.OFF)
dataset = dataset.with_options(options)
# Do not enable sharding if tf.data service is enabled, as sharding will be
# handled inside tf.data service.
if sharding and input_context and (input_context.num_input_pipelines > 1):
dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id)
if repeat:
dataset = dataset.repeat()
return dataset
def _shard_files_then_read(matched_files: List[str],
dataset_fn,
input_context: Optional[
tf.distribute.InputContext] = None,
seed: Optional[Union[int, tf.Tensor]] = None,
is_training: bool = False,
sharding: bool = False,
cache: bool = False,
cycle_length: Optional[int] = None,
block_length: Optional[int] = None,
deterministic: bool = False) -> tf.data.Dataset:
"""Shards the data files and then sent a split to every worker to read."""
dataset = tf.data.Dataset.from_tensor_slices(matched_files)
# Shuffle and repeat at file level.
# If cache is enabled, `reshuffle_each_iteration` is set to False,
# because we will read the same cached data in every iteration anyway.
if is_training:
# We need a seed to shuffle the files so that when each TPU workers gets
# its own shard the files do not overlap.
if sharding and seed is None:
seed = _get_random_integer()
dataset = dataset.shuffle(
len(matched_files),
seed=seed,
reshuffle_each_iteration=True if not cache else False)
# Do not enable sharding if tf.data service is enabled, as sharding will be
# handled inside tf.data service.
if sharding and input_context and (input_context.num_input_pipelines > 1):
dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id)
# If cache is enabled, we will call `repeat()` later after `cache()`.
if is_training and not cache:
dataset = dataset.repeat()
dataset = dataset.interleave(
map_func=dataset_fn,
cycle_length=cycle_length,
block_length=block_length,
num_parallel_calls=(cycle_length
if cycle_length else tf.data.experimental.AUTOTUNE),
deterministic=deterministic)
return dataset
def _read_tfds(tfds_builder: tfds.core.DatasetBuilder,
tfds_split: Text,
tfds_skip_decoding_feature: Text,
tfds_as_supervised: bool,
input_context: Optional[tf.distribute.InputContext] = None,
seed: Optional[Union[int, tf.Tensor]] = None,
is_training: bool = False,
cache: bool = False,
cycle_length: Optional[int] = None,
block_length: Optional[int] = None) -> tf.data.Dataset:
"""Reads a dataset from tfds."""
# No op if exist.
tfds_builder.download_and_prepare()
read_config = tfds.ReadConfig(
interleave_cycle_length=cycle_length,
interleave_block_length=block_length,
input_context=input_context,
shuffle_seed=seed)
decoders = {}
if tfds_skip_decoding_feature:
for skip_feature in tfds_skip_decoding_feature.split(','):
decoders[skip_feature.strip()] = tfds.decode.SkipDecoding()
dataset = tfds_builder.as_dataset(
split=tfds_split,
shuffle_files=is_training,
as_supervised=tfds_as_supervised,
decoders=decoders,
read_config=read_config)
if is_training and not cache:
dataset = dataset.repeat()
return dataset
class InputReader:
"""Input reader that returns a tf.data.Dataset instance."""
# A static random number which is the same across different InputReader
# instances.
static_randnum = _get_random_integer()
def __init__(self,
params: cfg.DataConfig,
dataset_fn=tf.data.TFRecordDataset,
decoder_fn: Optional[Callable[..., Any]] = None,
combine_fn: Optional[Callable[..., Any]] = None,
sample_fn: Optional[Callable[..., Any]] = None,
parser_fn: Optional[Callable[..., Any]] = None,
transform_and_batch_fn: Optional[Callable[
[tf.data.Dataset, Optional[tf.distribute.InputContext]],
tf.data.Dataset]] = None,
postprocess_fn: Optional[Callable[..., Any]] = None):
"""Initializes an InputReader instance.
Args:
params: A config_definitions.DataConfig object.
dataset_fn: A `tf.data.Dataset` that consumes the input files. For
example, it can be `tf.data.TFRecordDataset`.
decoder_fn: An optional `callable` that takes the serialized data string
and decodes them into the raw tensor dictionary.
combine_fn: An optional `callable` that takes a dictionarty of
`tf.data.Dataset` objects as input and outputs a combined dataset. It
will be executed after the decoder_fn and before the sample_fn.
sample_fn: An optional `callable` that takes a `tf.data.Dataset` object as
input and outputs the transformed dataset. It performs sampling on the
decoded raw tensors dict before the parser_fn.
parser_fn: An optional `callable` that takes the decoded raw tensors dict
and parse them into a dictionary of tensors that can be consumed by the
model. It will be executed after decoder_fn.
transform_and_batch_fn: An optional `callable` that takes a
`tf.data.Dataset` object and an optional `tf.distribute.InputContext` as
input, and returns a `tf.data.Dataset` object. It will be executed after
`parser_fn` to transform and batch the dataset; if None, after
`parser_fn` is executed, the dataset will be batched into per-replica
batch size.
postprocess_fn: A optional `callable` that processes batched tensors. It
will be executed after batching.
"""
if params.input_path and params.tfds_name:
raise ValueError('At most one of `input_path` and `tfds_name` can be '
'specified, but got %s and %s.' %
(params.input_path, params.tfds_name))
if isinstance(params.input_path,
cfg.base_config.Config) and combine_fn is None:
raise ValueError(
'A `combine_fn` is required if the `input_path` is a dictionary.')
self._tfds_builder = None
self._matched_files = None
if not params.input_path:
# Read dataset from TFDS.
if not params.tfds_split:
raise ValueError(
'`tfds_name` is %s, but `tfds_split` is not specified.' %
params.tfds_name)
self._tfds_builder = tfds.builder(
params.tfds_name, data_dir=params.tfds_data_dir)
else:
self._matched_files = self.get_files(params.input_path)
self._global_batch_size = params.global_batch_size
self._is_training = params.is_training
self._drop_remainder = params.drop_remainder
self._shuffle_buffer_size = params.shuffle_buffer_size
self._cache = params.cache
self._cycle_length = params.cycle_length
self._block_length = params.block_length
self._deterministic = params.deterministic
self._sharding = params.sharding
self._tfds_split = params.tfds_split
self._tfds_as_supervised = params.tfds_as_supervised
self._tfds_skip_decoding_feature = params.tfds_skip_decoding_feature
self._dataset_fn = dataset_fn
self._decoder_fn = decoder_fn
self._combine_fn = combine_fn
self._sample_fn = sample_fn
self._parser_fn = parser_fn
self._transform_and_batch_fn = transform_and_batch_fn
self._postprocess_fn = postprocess_fn
self._seed = params.seed
# When tf.data service is enabled, each data service worker should get
# different random seeds. Thus, we set `seed` to None.
# Sharding should also be disabled because tf data service handles how
# each worker shard data with `processing_mode` in distribute method.
if params.enable_tf_data_service:
self._seed = None
self._sharding = False
self._enable_tf_data_service = (
params.enable_tf_data_service and params.tf_data_service_address)
self._tf_data_service_address = params.tf_data_service_address
if self._enable_tf_data_service:
# Add a random seed as the tf.data service job name suffix, so tf.data
# service doesn't reuse the previous state if TPU worker gets preempted.
self._tf_data_service_job_name = (
params.tf_data_service_job_name + str(self.static_randnum))
self._enable_round_robin_tf_data_service = params.get(
'enable_round_robin_tf_data_service', False)
@property
def tfds_info(self) -> tfds.core.DatasetInfo:
"""Returns TFDS dataset info, if available."""
if self._tfds_builder:
return self._tfds_builder.info
else:
raise ValueError('tfds_info is not available, because the dataset '
'is not loaded from tfds.')
def get_files(self, input_path):
"""Gets matched files. Can be overridden by subclasses."""
if not input_path:
return None
# we want to combine / mix datasets
if isinstance(input_path, cfg.base_config.Config):
matched_files = {}
for k, v in input_path.as_dict().items():
matched_files[k] = match_files(v)
# single dataset
else:
matched_files = match_files(input_path)
return matched_files
def _read_data_source(
self,
matched_files: Union[Dict[str, List[str]], List[str]],
dataset_fn,
input_context: Optional[tf.distribute.InputContext] = None,
tfds_builder: Optional[tfds.core.DatasetBuilder] = None):
"""Reads the data source (files/tfds) to a dataset."""
def _files_to_dataset(files: List[str]) -> tf.data.Dataset:
if len(files) > 1:
if input_context and (len(files) < input_context.num_input_pipelines):
logging.warn(
'The number of files %d is less than the number of input pipelines '
'%d. We will send all input files to every worker. '
'Please consider sharding your data into more files.', len(files),
input_context.num_input_pipelines)
return _read_files_then_shard(
files,
dataset_fn,
input_context,
sharding=self._sharding,
repeat=self._is_training and not self._cache)
else:
return _shard_files_then_read(
files,
dataset_fn,
input_context,
seed=self._seed,
is_training=self._is_training,
sharding=self._sharding,
cache=self._cache,
cycle_length=self._cycle_length,
block_length=self._block_length,
deterministic=self._deterministic)
elif len(files) == 1:
return _read_files_then_shard(
files,
dataset_fn,
input_context,
sharding=self._sharding,
repeat=self._is_training and not self._cache)
else:
raise ValueError('It is unexpected that `tfds_builder` is None and '
'there is also no `files`.')
if tfds_builder:
dataset = _read_tfds(
tfds_builder=self._tfds_builder,
tfds_split=self._tfds_split,
tfds_skip_decoding_feature=self._tfds_skip_decoding_feature,
tfds_as_supervised=self._tfds_as_supervised,
input_context=input_context,
seed=self._seed,
is_training=self._is_training,
cache=self._cache,
cycle_length=self._cycle_length,
block_length=self._block_length)
elif isinstance(matched_files, (list, tuple)):
dataset = _files_to_dataset(matched_files)
elif isinstance(matched_files, dict):
dataset = {}
for k, fs in matched_files.items():
dataset[k] = _files_to_dataset(fs)
else:
raise ValueError('`matched_files` should be a list or dict.')
return dataset
def _decode_and_parse_dataset(
self,
dataset: Union[tf.data.Dataset, Dict[Text, tf.data.Dataset]],
batch_size: int,
input_context: Optional[tf.distribute.InputContext] = None
) -> tf.data.Dataset:
"""Returns a tf.data.Dataset object after shuffling, decoding, and parsing."""
def _shuffle_and_decode(ds):
# If cache is enabled, we will call `shuffle()` later after `cache()`.
if self._is_training and not self._cache:
ds = ds.shuffle(self._shuffle_buffer_size, seed=self._seed)
# Decode
ds = _maybe_map_fn(ds, self._decoder_fn)
return ds
dataset = tf.nest.map_structure(_shuffle_and_decode, dataset)
if tf.nest.is_nested(dataset):
dataset = self._combine_fn(dataset)
if self._sample_fn is not None:
dataset = dataset.apply(self._sample_fn)
dataset = _maybe_map_fn(dataset, self._parser_fn)
if self._cache:
dataset = dataset.cache()
if self._is_training:
dataset = dataset.repeat()
dataset = dataset.shuffle(self._shuffle_buffer_size, seed=self._seed)
if self._transform_and_batch_fn is not None:
dataset = self._transform_and_batch_fn(dataset, input_context)
else:
per_replica_batch_size = input_context.get_per_replica_batch_size(
batch_size) if input_context else batch_size
dataset = dataset.batch(
per_replica_batch_size, drop_remainder=self._drop_remainder)
return dataset
def _maybe_apply_data_service(
self,
dataset: tf.data.Dataset,
input_context: Optional[tf.distribute.InputContext] = None
) -> tf.data.Dataset:
"""Potentially distributes a dataset."""
if self._enable_tf_data_service and input_context:
if self._enable_round_robin_tf_data_service:
replicas_per_input_pipeline = input_context.num_replicas_in_sync // (
input_context.num_input_pipelines)
base_consumer_index = input_context.input_pipeline_id * (
replicas_per_input_pipeline)
num_consumers = input_context.num_input_pipelines * (
replicas_per_input_pipeline)
range_dataset = tf.data.Dataset.range(replicas_per_input_pipeline)
dataset = range_dataset.map(lambda i: dataset.apply( # pylint: disable=g-long-lambda
tf.data.experimental.service.distribute(
processing_mode='parallel_epochs',
service=self._tf_data_service_address,
job_name=self._tf_data_service_job_name,
consumer_index=base_consumer_index + i,
num_consumers=num_consumers)))
# Use parallel interleave to read multiple batches from a tf.data
# service worker in parallel.
dataset = dataset.interleave(
lambda x: x,
cycle_length=replicas_per_input_pipeline,
num_parallel_calls=replicas_per_input_pipeline,
deterministic=True)
else:
dataset = dataset.apply(
tf.data.experimental.service.distribute(
processing_mode='parallel_epochs',
service=self._tf_data_service_address,
job_name=self._tf_data_service_job_name))
return dataset
def read(self,
input_context: Optional[tf.distribute.InputContext] = None,
dataset: Optional[tf.data.Dataset] = None) -> tf.data.Dataset:
"""Generates a tf.data.Dataset object."""
if dataset is None:
dataset = self._read_data_source(
self._matched_files, self._dataset_fn, input_context,
self._tfds_builder)
dataset = self._decode_and_parse_dataset(dataset, self._global_batch_size,
input_context)
dataset = _maybe_map_fn(dataset, self._postprocess_fn)
dataset = self._maybe_apply_data_service(dataset, input_context)
if self._deterministic is not None:
options = tf.data.Options()
options.experimental_deterministic = self._deterministic
dataset = dataset.with_options(options)
return dataset.prefetch(tf.data.experimental.AUTOTUNE)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Registry utility."""
def register(registered_collection, reg_key):
"""Register decorated function or class to collection.
Register decorated function or class into registered_collection, in a
hierarchical order. For example, when reg_key="my_model/my_exp/my_config_0"
the decorated function or class is stored under
registered_collection["my_model"]["my_exp"]["my_config_0"].
This decorator is supposed to be used together with the lookup() function in
this file.
Args:
registered_collection: a dictionary. The decorated function or class will be
put into this collection.
reg_key: The key for retrieving the registered function or class. If reg_key
is a string, it can be hierarchical like my_model/my_exp/my_config_0
Returns:
A decorator function
Raises:
KeyError: when function or class to register already exists.
"""
def decorator(fn_or_cls):
"""Put fn_or_cls in the dictionary."""
if isinstance(reg_key, str):
hierarchy = reg_key.split("/")
collection = registered_collection
for h_idx, entry_name in enumerate(hierarchy[:-1]):
if entry_name not in collection:
collection[entry_name] = {}
collection = collection[entry_name]
if not isinstance(collection, dict):
raise KeyError(
"Collection path {} at position {} already registered as "
"a function or class.".format(entry_name, h_idx))
leaf_reg_key = hierarchy[-1]
else:
collection = registered_collection
leaf_reg_key = reg_key
if leaf_reg_key in collection:
raise KeyError("Function or class {} registered multiple times.".format(
leaf_reg_key))
collection[leaf_reg_key] = fn_or_cls
return fn_or_cls
return decorator
def lookup(registered_collection, reg_key):
"""Lookup and return decorated function or class in the collection.
Lookup decorated function or class in registered_collection, in a
hierarchical order. For example, when
reg_key="my_model/my_exp/my_config_0",
this function will return
registered_collection["my_model"]["my_exp"]["my_config_0"].
Args:
registered_collection: a dictionary. The decorated function or class will be
retrieved from this collection.
reg_key: The key for retrieving the registered function or class. If reg_key
is a string, it can be hierarchical like my_model/my_exp/my_config_0
Returns:
The registered function or class.
Raises:
LookupError: when reg_key cannot be found.
"""
if isinstance(reg_key, str):
hierarchy = reg_key.split("/")
collection = registered_collection
for h_idx, entry_name in enumerate(hierarchy):
if entry_name not in collection:
raise LookupError(
f"collection path {entry_name} at position {h_idx} is never "
f"registered. Please make sure the {entry_name} and its library is "
"imported and linked to the trainer binary.")
collection = collection[entry_name]
return collection
else:
if reg_key not in registered_collection:
raise LookupError(
f"registration key {reg_key} is never "
f"registered. Please make sure the {reg_key} and its library is "
"imported and linked to the trainer binary.")
return registered_collection[reg_key]
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for registry."""
import tensorflow as tf
from official.core import registry
class RegistryTest(tf.test.TestCase):
def test_register(self):
collection = {}
@registry.register(collection, 'functions/func_0')
def func_test():
pass
self.assertEqual(registry.lookup(collection, 'functions/func_0'), func_test)
@registry.register(collection, 'classes/cls_0')
class ClassRegistryKey:
pass
self.assertEqual(
registry.lookup(collection, 'classes/cls_0'), ClassRegistryKey)
@registry.register(collection, ClassRegistryKey)
class ClassRegistryValue:
pass
self.assertEqual(
registry.lookup(collection, ClassRegistryKey), ClassRegistryValue)
def test_register_hierarchy(self):
collection = {}
@registry.register(collection, 'functions/func_0')
def func_test0():
pass
@registry.register(collection, 'func_1')
def func_test1():
pass
@registry.register(collection, func_test1)
def func_test2():
pass
expected_collection = {
'functions': {
'func_0': func_test0,
},
'func_1': func_test1,
func_test1: func_test2,
}
self.assertEqual(collection, expected_collection)
def test_register_error(self):
collection = {}
@registry.register(collection, 'functions/func_0')
def func_test0(): # pylint: disable=unused-variable
pass
with self.assertRaises(KeyError):
@registry.register(collection, 'functions/func_0/sub_func')
def func_test1(): # pylint: disable=unused-variable
pass
with self.assertRaises(LookupError):
registry.lookup(collection, 'non-exist')
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A global factory to register and access all registered tasks."""
from official.core import registry
_REGISTERED_TASK_CLS = {}
# TODO(b/158741360): Add type annotations once pytype checks across modules.
def register_task_cls(task_config_cls):
"""Decorates a factory of Tasks for lookup by a subclass of TaskConfig.
This decorator supports registration of tasks as follows:
```
@dataclasses.dataclass
class MyTaskConfig(TaskConfig):
# Add fields here.
pass
@register_task_cls(MyTaskConfig)
class MyTask(Task):
# Inherits def __init__(self, task_config).
pass
my_task_config = MyTaskConfig()
my_task = get_task(my_task_config) # Returns MyTask(my_task_config).
```
Besisdes a class itself, other callables that create a Task from a TaskConfig
can be decorated by the result of this function, as long as there is at most
one registration for each config class.
Args:
task_config_cls: a subclass of TaskConfig (*not* an instance of TaskConfig).
Each task_config_cls can only be used for a single registration.
Returns:
A callable for use as class decorator that registers the decorated class
for creation from an instance of task_config_cls.
"""
return registry.register(_REGISTERED_TASK_CLS, task_config_cls)
def get_task(task_config, **kwargs):
"""Creates a Task (of suitable subclass type) from task_config."""
# TODO(hongkuny): deprecate the task factory to use config.BUILDER.
if task_config.BUILDER is not None:
return task_config.BUILDER(task_config, **kwargs)
return get_task_cls(task_config.__class__)(task_config, **kwargs)
# The user-visible get_task() is defined after classes have been registered.
# TODO(b/158741360): Add type annotations once pytype checks across modules.
def get_task_cls(task_config_cls):
task_cls = registry.lookup(_REGISTERED_TASK_CLS, task_config_cls)
return task_cls
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for testing."""
import tensorflow as tf
class FakeKerasModel(tf.keras.Model):
"""Fake keras model for testing."""
def __init__(self):
super().__init__()
self.dense = tf.keras.layers.Dense(4, activation=tf.nn.relu)
self.dense2 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
def call(self, inputs):
return self.dense2(self.dense(inputs))
class _Dense(tf.Module):
"""A dense layer."""
def __init__(self, input_dim, output_size, name=None):
super().__init__(name=name)
with self.name_scope:
self.w = tf.Variable(
tf.random.normal([input_dim, output_size]), name='w')
self.b = tf.Variable(tf.zeros([output_size]), name='b')
@tf.Module.with_name_scope
def __call__(self, x):
y = tf.matmul(x, self.w) + self.b
return tf.nn.relu(y)
class FakeModule(tf.Module):
"""Fake model using tf.Module for testing."""
def __init__(self, input_size, name=None):
super().__init__(name=name)
with self.name_scope:
self.dense = _Dense(input_size, 4, name='dense')
self.dense2 = _Dense(4, 4, name='dense_1')
@tf.Module.with_name_scope
def __call__(self, x):
return self.dense2(self.dense(x))
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TFM common training driver library."""
# pytype: disable=attribute-error
import os
from typing import Any, Mapping, Optional, Tuple
# Import libraries
from absl import logging
import orbit
import tensorflow as tf
from official.core import actions
from official.core import base_task
from official.core import base_trainer
from official.core import config_definitions
from official.core import train_utils
maybe_create_best_ckpt_exporter = train_utils.maybe_create_best_ckpt_exporter
def run_experiment(
distribution_strategy: tf.distribute.Strategy,
task: base_task.Task,
mode: str,
params: config_definitions.ExperimentConfig,
model_dir: str,
run_post_eval: bool = False,
save_summary: bool = True,
trainer: Optional[base_trainer.Trainer] = None,
controller_cls=orbit.Controller
) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
"""Runs train/eval configured by the experiment params.
Args:
distribution_strategy: A distribution distribution_strategy.
task: A Task instance.
mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
or 'continuous_eval'.
params: ExperimentConfig instance.
model_dir: A 'str', a path to store model checkpoints and summaries.
run_post_eval: Whether to run post eval once after training, metrics logs
are returned.
save_summary: Whether to save train and validation summary.
trainer: the base_trainer.Trainer instance. It should be created within the
strategy.scope().
controller_cls: The controller class to manage the train and eval process.
Must be a orbit.Controller subclass.
Returns:
A 2-tuple of (model, eval_logs).
model: `tf.keras.Model` instance.
eval_logs: returns eval metrics logs when run_post_eval is set to True,
otherwise, returns {}.
"""
with distribution_strategy.scope():
if not trainer:
trainer = train_utils.create_trainer(
params,
task,
train='train' in mode,
evaluate=('eval' in mode) or run_post_eval,
checkpoint_exporter=maybe_create_best_ckpt_exporter(
params, model_dir))
if trainer.checkpoint:
if model_dir is None:
raise ValueError('model_dir must be specified, but got None')
checkpoint_manager = tf.train.CheckpointManager(
trainer.checkpoint,
directory=model_dir,
max_to_keep=params.trainer.max_to_keep,
step_counter=trainer.global_step,
checkpoint_interval=params.trainer.checkpoint_interval,
init_fn=trainer.initialize)
else:
checkpoint_manager = None
controller = controller_cls(
strategy=distribution_strategy,
trainer=trainer if 'train' in mode else None,
evaluator=trainer,
global_step=trainer.global_step,
steps_per_loop=params.trainer.steps_per_loop,
checkpoint_manager=checkpoint_manager,
summary_dir=os.path.join(model_dir, 'train') if (save_summary) else None,
eval_summary_dir=os.path.join(model_dir,
params.trainer.validation_summary_subdir) if
(save_summary) else None,
summary_interval=params.trainer.summary_interval if
(save_summary) else None,
train_actions=actions.get_train_actions(
params, trainer, model_dir, checkpoint_manager=checkpoint_manager),
eval_actions=actions.get_eval_actions(params, trainer, model_dir))
logging.info('Starts to execute mode: %s', mode)
with distribution_strategy.scope():
if mode == 'train':
controller.train(steps=params.trainer.train_steps)
elif mode == 'train_and_eval':
controller.train_and_evaluate(
train_steps=params.trainer.train_steps,
eval_steps=params.trainer.validation_steps,
eval_interval=params.trainer.validation_interval)
elif mode == 'eval':
controller.evaluate(steps=params.trainer.validation_steps)
elif mode == 'continuous_eval':
def timeout_fn():
if trainer.global_step.numpy() >= params.trainer.train_steps:
return True
return False
controller.evaluate_continuously(
steps=params.trainer.validation_steps,
timeout=params.trainer.continuous_eval_timeout,
timeout_fn=timeout_fn)
else:
raise NotImplementedError('The mode is not implemented: %s' % mode)
num_params = train_utils.try_count_params(trainer.model)
if num_params is not None:
logging.info('Number of trainable params in model: %f Millions.',
num_params / 10.**6)
flops = train_utils.try_count_flops(trainer.model)
if flops is not None:
logging.info('FLOPs (multi-adds) in model: %f Billions.',
flops / 10.**9 / 2)
if run_post_eval:
with distribution_strategy.scope():
return trainer.model, trainer.evaluate(
tf.convert_to_tensor(params.trainer.validation_steps))
else:
return trainer.model, {}
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for train_ctl_lib."""
import json
import os
from absl import flags
from absl.testing import flagsaver
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.common import flags as tfm_flags
# pylint: disable=unused-import
from official.common import registry_imports
# pylint: enable=unused-import
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.utils.testing import mock_task
FLAGS = flags.FLAGS
tfm_flags.define_flags()
class TrainTest(tf.test.TestCase, parameterized.TestCase):
def setUp(self):
super(TrainTest, self).setUp()
self._test_config = {
'trainer': {
'checkpoint_interval': 10,
'steps_per_loop': 10,
'summary_interval': 10,
'train_steps': 10,
'validation_steps': 5,
'validation_interval': 10,
'continuous_eval_timeout': 1,
'validation_summary_subdir': 'validation',
'optimizer_config': {
'optimizer': {
'type': 'sgd',
},
'learning_rate': {
'type': 'constant'
}
}
},
}
@combinations.generate(
combinations.combine(
distribution_strategy=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],
flag_mode=['train', 'eval', 'train_and_eval'],
run_post_eval=[True, False]))
def test_end_to_end(self, distribution_strategy, flag_mode, run_post_eval):
model_dir = self.get_temp_dir()
flags_dict = dict(
experiment='mock',
mode=flag_mode,
model_dir=model_dir,
params_override=json.dumps(self._test_config))
with flagsaver.flagsaver(**flags_dict):
params = train_utils.parse_configuration(flags.FLAGS)
train_utils.serialize_config(params, model_dir)
with distribution_strategy.scope():
task = task_factory.get_task(params.task, logging_dir=model_dir)
_, logs = train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode=flag_mode,
params=params,
model_dir=model_dir,
run_post_eval=run_post_eval)
if 'eval' in flag_mode:
self.assertTrue(
tf.io.gfile.exists(
os.path.join(model_dir,
params.trainer.validation_summary_subdir)))
if run_post_eval:
self.assertNotEmpty(logs)
else:
self.assertEmpty(logs)
self.assertNotEmpty(
tf.io.gfile.glob(os.path.join(model_dir, 'params.yaml')))
if flag_mode == 'eval':
return
self.assertNotEmpty(
tf.io.gfile.glob(os.path.join(model_dir, 'checkpoint')))
# Tests continuous evaluation.
_, logs = train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode='continuous_eval',
params=params,
model_dir=model_dir,
run_post_eval=run_post_eval)
@combinations.generate(
combinations.combine(
distribution_strategy=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],
flag_mode=['train', 'train_and_eval'],
))
def test_recovery_nan_error(self, distribution_strategy, flag_mode):
model_dir = self.get_temp_dir()
flags_dict = dict(
experiment='mock',
mode=flag_mode,
model_dir=model_dir,
params_override=json.dumps(self._test_config))
with flagsaver.flagsaver(**flags_dict):
params = train_utils.parse_configuration(flags.FLAGS)
train_utils.serialize_config(params, model_dir)
with distribution_strategy.scope():
# task = task_factory.get_task(params.task, logging_dir=model_dir)
task = mock_task.MockTask(params.task, logging_dir=model_dir)
# Set the loss to NaN to trigger RunTimeError.
def build_losses(labels, model_outputs, aux_losses=None):
del labels, model_outputs
return tf.constant([np.nan], tf.float32) + aux_losses
task.build_losses = build_losses
with self.assertRaises(RuntimeError):
train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode=flag_mode,
params=params,
model_dir=model_dir)
@combinations.generate(
combinations.combine(
distribution_strategy=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],
flag_mode=['train'],
))
def test_recovery(self, distribution_strategy, flag_mode):
loss_threshold = 1.0
model_dir = self.get_temp_dir()
flags_dict = dict(
experiment='mock',
mode=flag_mode,
model_dir=model_dir,
params_override=json.dumps(self._test_config))
with flagsaver.flagsaver(**flags_dict):
params = train_utils.parse_configuration(flags.FLAGS)
params.trainer.loss_upper_bound = loss_threshold
params.trainer.recovery_max_trials = 1
train_utils.serialize_config(params, model_dir)
with distribution_strategy.scope():
task = task_factory.get_task(params.task, logging_dir=model_dir)
# Saves a checkpoint for reference.
model = task.build_model()
checkpoint = tf.train.Checkpoint(model=model)
checkpoint_manager = tf.train.CheckpointManager(
checkpoint, self.get_temp_dir(), max_to_keep=2)
checkpoint_manager.save()
before_weights = model.get_weights()
def build_losses(labels, model_outputs, aux_losses=None):
del labels, model_outputs
return tf.constant([loss_threshold], tf.float32) + aux_losses
task.build_losses = build_losses
model, _ = train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode=flag_mode,
params=params,
model_dir=model_dir)
after_weights = model.get_weights()
for left, right in zip(before_weights, after_weights):
self.assertAllEqual(left, right)
def test_parse_configuration(self):
model_dir = self.get_temp_dir()
flags_dict = dict(
experiment='mock',
mode='train',
model_dir=model_dir,
params_override=json.dumps(self._test_config))
with flagsaver.flagsaver(**flags_dict):
params = train_utils.parse_configuration(flags.FLAGS, lock_return=True)
with self.assertRaises(ValueError):
params.override({'task': {'init_checkpoint': 'Foo'}})
params = train_utils.parse_configuration(flags.FLAGS, lock_return=False)
params.override({'task': {'init_checkpoint': 'Bar'}})
self.assertEqual(params.task.init_checkpoint, 'Bar')
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Training utils."""
import copy
import json
import os
import pprint
from typing import Any, Callable, Dict, List, Optional, Union
from absl import logging
import dataclasses
import gin
import orbit
import tensorflow as tf
# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2_as_graph
# pylint: enable=g-direct-tensorflow-import
from official.core import base_task
from official.core import base_trainer
from official.core import config_definitions
from official.core import exp_factory
from official.modeling import hyperparams
def get_leaf_nested_dict(d: Dict[str, Any], keys: List[str]) -> Dict[str, Any]:
"""Get leaf from a dictionary with arbitrary depth with a list of keys.
Args:
d: The dictionary to extract value from.
keys: The list of keys to extract values recursively.
Returns:
The value of the leaf.
Raises:
KeyError: If the value of keys extracted is a dictionary.
"""
leaf = d
for k in keys:
if not isinstance(leaf, dict) or k not in leaf:
raise KeyError(
'Path not exist while traversing the dictionary: d with keys'
': %s.' % keys)
leaf = leaf[k]
if isinstance(leaf, dict):
raise KeyError('The value extracted with keys: %s is not a leaf of the '
'dictionary: %s.' % (keys, d))
return leaf
def cast_leaf_nested_dict(d: Dict[str, Any],
cast_fn: Callable[[Any], Any]) -> Dict[str, Any]:
"""Cast the leaves of a dictionary with arbitrary depth in place.
Args:
d: The dictionary to extract value from.
cast_fn: The casting function.
Returns:
A dictionray with the same structure as d.
"""
for key, value in d.items():
if isinstance(value, dict):
d[key] = cast_leaf_nested_dict(value, cast_fn)
else:
d[key] = cast_fn(value)
return d
def maybe_create_best_ckpt_exporter(params: config_definitions.ExperimentConfig,
data_dir: str) -> Any:
"""Maybe create a BestCheckpointExporter object, according to the config."""
export_subdir = params.trainer.best_checkpoint_export_subdir
metric_name = params.trainer.best_checkpoint_eval_metric
metric_comp = params.trainer.best_checkpoint_metric_comp
if data_dir and export_subdir and metric_name:
best_ckpt_dir = os.path.join(data_dir, export_subdir)
best_ckpt_exporter = BestCheckpointExporter(best_ckpt_dir, metric_name,
metric_comp)
logging.info(
'Created the best checkpoint exporter. '
'data_dir: %s, export_subdir: %s, metric_name: %s', data_dir,
export_subdir, metric_name)
else:
best_ckpt_exporter = None
return best_ckpt_exporter
# TODO(b/180147589): Add tests for this module.
class BestCheckpointExporter:
"""Keeps track of the best result, and saves its checkpoint.
Orbit will support an API for checkpoint exporter. This class will be used
together with orbit once this functionality is ready.
"""
def __init__(self, export_dir: str, metric_name: str, metric_comp: str):
"""Initialization.
Args:
export_dir: The directory that will contain exported checkpoints.
metric_name: Indicates which metric to look at, when determining which
result is better. If eval_logs being passed to maybe_export_checkpoint
is a nested dictionary, use `|` as a seperator for different layers.
metric_comp: Indicates how to compare results. Either `lower` or `higher`.
"""
self._export_dir = export_dir
self._metric_name = metric_name.split('|')
self._metric_comp = metric_comp
if self._metric_comp not in ('lower', 'higher'):
raise ValueError('best checkpoint metric comp must be one of '
'higher, lower. Got: {}'.format(self._metric_comp))
tf.io.gfile.makedirs(os.path.dirname(self.best_ckpt_logs_path))
self._best_ckpt_logs = self._maybe_load_best_eval_metric()
self._checkpoint_manager = None
def _get_checkpoint_manager(self, checkpoint):
"""Gets an existing checkpoint manager or creates a new one."""
if self._checkpoint_manager is None or (self._checkpoint_manager.checkpoint
!= checkpoint):
logging.info('Creates a new checkpoint manager.')
self._checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
directory=self._export_dir,
max_to_keep=1,
checkpoint_name='best_ckpt')
return self._checkpoint_manager
def maybe_export_checkpoint(
self, checkpoint, eval_logs, global_step, write_logs=True) -> bool:
"""Compare eval_logs with past eval_logs and export checkpoint if better."""
logging.info('[BestCheckpointExporter] received eval_logs: %s, at step: %d',
eval_logs, global_step)
if self._best_ckpt_logs is None or self._new_metric_is_better(
self._best_ckpt_logs, eval_logs):
self._best_ckpt_logs = eval_logs
if write_logs:
self.export_best_eval_metric(self._best_ckpt_logs, global_step)
self._get_checkpoint_manager(checkpoint).save()
return True
return False
def _maybe_load_best_eval_metric(self):
if not tf.io.gfile.exists(self.best_ckpt_logs_path):
return None
with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'r') as reader:
return json.loads(reader.read())
def _new_metric_is_better(self, old_logs, new_logs):
"""Check if the metric in new_logs is better than the metric in old_logs."""
old_value = float(
orbit.utils.get_value(
get_leaf_nested_dict(old_logs, self._metric_name)))
new_value = float(
orbit.utils.get_value(
get_leaf_nested_dict(new_logs, self._metric_name)))
logging.info('[BestCheckpointExporter] comparing results. old: %f, new: %f',
old_value, new_value)
if self._metric_comp == 'higher':
if new_value > old_value:
logging.info('[BestCheckpointExporter] '
'the new number is better since it is higher.')
return True
else: # self._metric_comp == 'lower':
if new_value < old_value:
logging.info('[BestCheckpointExporter] '
'the new number is better since it is lower.')
return True
return False
def export_best_eval_metric(self, eval_logs, global_step):
"""Export evaluation results of the best checkpoint into a json file."""
eval_logs_ext = copy.copy(eval_logs)
eval_logs_ext['best_ckpt_global_step'] = global_step
eval_logs_ext = cast_leaf_nested_dict(
eval_logs_ext, lambda x: float(orbit.utils.get_value(x)))
# Saving json file is very fast.
with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'w') as writer:
writer.write(json.dumps(eval_logs_ext, indent=4) + '\n')
@property
def best_ckpt_logs(self):
return self._best_ckpt_logs
@property
def best_ckpt_logs_path(self):
return os.path.join(self._export_dir, 'info.json')
@property
def best_ckpt_path(self):
"""Returns the best ckpt path or None if there is no ckpt yet."""
return tf.train.latest_checkpoint(self._export_dir)
@gin.configurable
def create_trainer(params: config_definitions.ExperimentConfig,
task: base_task.Task,
train: bool,
evaluate: bool,
checkpoint_exporter: Optional[BestCheckpointExporter] = None,
trainer_cls=base_trainer.Trainer) -> base_trainer.Trainer:
"""Create trainer."""
logging.info('Running default trainer.')
model = task.build_model()
optimizer = task.create_optimizer(params.trainer.optimizer_config,
params.runtime)
return trainer_cls(
params,
task,
model=model,
optimizer=optimizer,
train=train,
evaluate=evaluate,
checkpoint_exporter=checkpoint_exporter)
@dataclasses.dataclass
class ParseConfigOptions:
"""Use this dataclass instead of FLAGS to customize parse_configuration()."""
experiment: str
config_file: List[str]
tpu: str = ''
tf_data_service: str = ''
params_override: str = ''
def __contains__(self, name):
return name in dataclasses.asdict(self)
def parse_configuration(flags_obj, lock_return=True, print_return=True):
"""Parses ExperimentConfig from flags."""
if flags_obj.experiment is None:
raise ValueError('The flag --experiment must be specified.')
# 1. Get the default config from the registered experiment.
params = exp_factory.get_exp_config(flags_obj.experiment)
# 2. Get the first level of override from `--config_file`.
# `--config_file` is typically used as a template that specifies the common
# override for a particular experiment.
for config_file in flags_obj.config_file or []:
params = hyperparams.override_params_dict(
params, config_file, is_strict=True)
# 3. Override the TPU address and tf.data service address.
params.override({
'runtime': {
'tpu': flags_obj.tpu,
},
})
if ('tf_data_service' in flags_obj and flags_obj.tf_data_service and
isinstance(params.task, config_definitions.TaskConfig)):
params.override({
'task': {
'train_data': {
'tf_data_service_address': flags_obj.tf_data_service,
},
'validation_data': {
'tf_data_service_address': flags_obj.tf_data_service,
}
}
})
# 4. Get the second level of override from `--params_override`.
# `--params_override` is typically used as a further override over the
# template. For example, one may define a particular template for training
# ResNet50 on ImageNet in a config file and pass it via `--config_file`,
# then define different learning rates and pass it via `--params_override`.
if flags_obj.params_override:
params = hyperparams.override_params_dict(
params, flags_obj.params_override, is_strict=True)
params.validate()
if lock_return:
params.lock()
if print_return:
pp = pprint.PrettyPrinter()
logging.info('Final experiment parameters:\n%s',
pp.pformat(params.as_dict()))
return params
def serialize_config(params: config_definitions.ExperimentConfig,
model_dir: str):
"""Serializes and saves the experiment config."""
if model_dir is None:
raise ValueError('model_dir must be specified, but got None')
params_save_path = os.path.join(model_dir, 'params.yaml')
logging.info('Saving experiment configuration to %s', params_save_path)
tf.io.gfile.makedirs(model_dir)
hyperparams.save_params_dict_to_yaml(params, params_save_path)
def save_gin_config(filename_suffix: str, model_dir: str):
"""Serializes and saves the experiment config."""
gin_save_path = os.path.join(
model_dir, 'operative_config.{}.gin'.format(filename_suffix))
logging.info('Saving gin configurations to %s', gin_save_path)
tf.io.gfile.makedirs(model_dir)
with tf.io.gfile.GFile(gin_save_path, 'w') as f:
f.write(gin.operative_config_str())
def read_global_step_from_checkpoint(ckpt_file_path):
"""Read global step from checkpoint, or get global step from its filename."""
global_step = tf.Variable(-1, dtype=tf.int64)
ckpt = tf.train.Checkpoint(global_step=global_step)
try:
ckpt.restore(ckpt_file_path).expect_partial()
global_step_maybe_restored = global_step.numpy()
except tf.errors.InvalidArgumentError:
global_step_maybe_restored = -1
if global_step_maybe_restored == -1:
raise ValueError('global_step not found in checkpoint {}. '
'If you want to run finetune eval jobs, you need to '
'make sure that your pretrain model writes '
'global_step in its checkpoints.'.format(ckpt_file_path))
global_step_restored = global_step.numpy()
logging.info('get global_step %d from checkpoint %s', global_step_restored,
ckpt_file_path)
return global_step_restored
def write_json_summary(log_dir, global_step, eval_metrics):
"""Dump evaluation metrics to json file."""
serializable_dict = {}
for name, value in eval_metrics.items():
if hasattr(value, 'numpy'):
serializable_dict[name] = str(value.numpy())
else:
serializable_dict[name] = str(value)
output_json = os.path.join(log_dir, 'metrics-{}.json'.format(global_step))
logging.info('Evaluation results at pretrain step %d: %s', global_step,
serializable_dict)
with tf.io.gfile.GFile(output_json, 'w') as writer:
writer.write(json.dumps(serializable_dict, indent=4) + '\n')
def write_summary(summary_writer, global_step, eval_metrics):
"""Write evaluation metrics to TF summary."""
numeric_dict = {}
for name, value in eval_metrics.items():
numeric_dict[name] = float(orbit.utils.get_value(value))
with summary_writer.as_default():
for name, value in numeric_dict.items():
tf.summary.scalar(name, value, step=global_step)
summary_writer.flush()
def remove_ckpts(model_dir):
"""Remove model checkpoints, so we can restart."""
ckpts = os.path.join(model_dir, 'ckpt-*')
logging.info('removing checkpoint files %s', ckpts)
for file_to_remove in tf.io.gfile.glob(ckpts):
tf.io.gfile.rmtree(file_to_remove)
file_to_remove = os.path.join(model_dir, 'checkpoint')
if tf.io.gfile.exists(file_to_remove):
tf.io.gfile.remove(file_to_remove)
def write_model_params(model: Union[tf.Module, tf.keras.Model],
output_path: str) -> None:
"""Writes the model parameters and shapes to a file.
Args:
model: A model instance.
output_path: Output file path.
"""
with tf.io.gfile.GFile(output_path, 'w') as f:
total_params = 0
for var in model.variables:
shape = tf.shape(var)
total_params += tf.math.reduce_prod(shape).numpy()
f.write(f'{var.name} {shape.numpy().tolist()}\n')
f.write(f'\nTotal params: {total_params}\n')
def try_count_params(
model: Union[tf.Module, tf.keras.Model],
trainable_only: bool = False):
"""Count the number of parameters if model is possible.
Args:
model: Try to count the number of params in this model.
trainable_only: Whether to calculate trainable params only. This flag is
not used when the model has `count_params` attribute.
Returns:
The number of parameters or None.
"""
if hasattr(model, 'count_params'):
try:
return model.count_params()
except ValueError:
logging.info('Number of trainable params unknown, because the build() '
'methods in keras layers were not called. This is probably '
'because the model was not feed any input, e.g., the max '
'train step already reached before this run.')
return None
else:
total_params = 0
variables = model.trainable_variables if trainable_only else model.variables
for var in variables:
shape = tf.shape(var)
total_params += tf.math.reduce_prod(shape).numpy()
return total_params
def try_count_flops(model: Union[tf.Module, tf.keras.Model],
inputs_kwargs: Optional[Dict[str, Any]] = None,
output_path: Optional[str] = None):
"""Counts and returns model FLOPs.
Args:
model: A model instance.
inputs_kwargs: An optional dictionary of argument pairs specifying inputs'
shape specifications to getting corresponding concrete function.
output_path: A file path to write the profiling results to.
Returns:
The model's FLOPs.
"""
if hasattr(model, 'inputs'):
try:
# Get input shape and set batch size to 1.
if model.inputs:
inputs = [
tf.TensorSpec([1] + input.shape[1:], input.dtype)
for input in model.inputs
]
concrete_func = tf.function(model).get_concrete_function(inputs)
# If model.inputs is invalid, try to use the input to get concrete
# function for model.call (subclass model).
else:
concrete_func = tf.function(model.call).get_concrete_function(
**inputs_kwargs)
frozen_func, _ = convert_variables_to_constants_v2_as_graph(concrete_func)
# Calculate FLOPs.
run_meta = tf.compat.v1.RunMetadata()
opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
if output_path is not None:
opts['output'] = f'file:outfile={output_path}'
else:
opts['output'] = 'none'
flops = tf.compat.v1.profiler.profile(
graph=frozen_func.graph, run_meta=run_meta, options=opts)
return flops.total_float_ops
except Exception as e: # pylint: disable=broad-except
logging.info(
'Failed to count model FLOPs with error %s, because the build() '
'methods in keras layers were not called. This is probably because '
'the model was not feed any input, e.g., the max train step already '
'reached before this run.', e)
return None
return None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment