Commit 78c43ef1 authored by Gunho Park

Merge branch 'master' of https://github.com/tensorflow/models

parents 67cfc95b e3c7e300
@@ -30,3 +30,18 @@ If you want to contribute, please review the [contribution guidelines](https://g
## License
[Apache License 2.0](LICENSE)
## Citing TensorFlow Model Garden
If you use TensorFlow Model Garden in your research, please cite this repository.
```
@misc{tensorflowmodelgarden2020,
author = {Hongkun Yu and Chen Chen and Xianzhi Du and Yeqing Li and
Abdullah Rashwan and Le Hou and Pengchong Jin and Fan Yang and
Frederick Liu and Jaeyoun Kim and Jing Li},
title = {{TensorFlow Model Garden}},
howpublished = {\url{https://github.com/tensorflow/models}},
year = {2020}
}
```
@@ -40,7 +40,7 @@ In the near future, we will add:
| Model | Reference (Paper) |
|-------|-------------------|
| [MNIST](vision/image_classification) | A basic model to classify digits from the [MNIST dataset](http://yann.lecun.com/exdb/mnist/) |
- | [ResNet](vision/image_classification) | [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) |
+ | [ResNet](vision/beta/MODEL_GARDEN.md) | [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) |
| [ResNet-RS](vision/beta/MODEL_GARDEN.md) | [Revisiting ResNets: Improved Training and Scaling Strategies](https://arxiv.org/abs/2103.07579) |
| [EfficientNet](vision/image_classification) | [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks](https://arxiv.org/abs/1905.11946) |
@@ -48,10 +48,10 @@ In the near future, we will add:
| Model | Reference (Paper) |
|-------|-------------------|
- | [RetinaNet](vision/detection) | [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002) |
+ | [RetinaNet](vision/beta/MODEL_GARDEN.md) | [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002) |
- | [Mask R-CNN](vision/detection) | [Mask R-CNN](https://arxiv.org/abs/1703.06870) |
+ | [Mask R-CNN](vision/beta/MODEL_GARDEN.md) | [Mask R-CNN](https://arxiv.org/abs/1703.06870) |
| [ShapeMask](vision/detection) | [ShapeMask: Learning to Segment Novel Objects by Refining Shape Priors](https://arxiv.org/abs/1904.03239) |
- | [SpineNet](vision/detection) | [SpineNet: Learning Scale-Permuted Backbone for Recognition and Localization](https://arxiv.org/abs/1912.05027) |
+ | [SpineNet](vision/beta/MODEL_GARDEN.md) | [SpineNet: Learning Scale-Permuted Backbone for Recognition and Localization](https://arxiv.org/abs/1912.05027) |
### Natural Language Processing
@@ -163,17 +163,3 @@ pip3 install tensorflow-text-nightly
## Contributions
If you want to contribute, please review the [contribution guidelines](https://github.com/tensorflow/models/wiki/How-to-contribute).
- ## Citing TF Official Model Garden
- To cite this repository:
- ```
- @software{tfmodels2020github,
- author = {Chen Chen and Xianzhi Du and Le Hou and Jaeyoun Kim and Jing Li and
- Yeqing Li and Abdullah Rashwan and Fan Yang and Hongkun Yu},
- title = {TensorFlow Official Model Garden},
- url = {https://github.com/tensorflow/models/tree/master/official},
- year = {2020},
- }
- ```
@@ -102,8 +102,10 @@ def get_distribution_strategy(distribution_strategy="mirrored",
  distribution_strategy: a string specifying which distribution strategy to
    use. Accepted values are "off", "one_device", "mirrored",
    "parameter_server", "multi_worker_mirrored", and "tpu" -- case
-   insensitive. "off" means not to use Distribution Strategy; "tpu" means to
-   use TPUStrategy using `tpu_address`.
+   insensitive. "tpu" means to use TPUStrategy using `tpu_address`.
+   "off" means to use the default strategy which is obtained from
+   tf.distribute.get_strategy (for details on the default strategy, see
+   https://www.tensorflow.org/guide/distributed_training#default_strategy).
  num_gpus: Number of GPUs to run this model.
  all_reduce_alg: Optional. Specifies which algorithm to use when performing
    all-reduce. For `MirroredStrategy`, valid values are "nccl" and
@@ -141,7 +143,8 @@ def get_distribution_strategy(distribution_strategy="mirrored",
    if num_gpus > 1:
      raise ValueError("When {} GPUs are specified, distribution_strategy "
                       "flag cannot be set to `off`.".format(num_gpus))
-   return None
+   # Return the default distribution strategy.
+   return tf.distribute.get_strategy()
  if distribution_strategy == "tpu":
    # When tpu_address is an empty string, we communicate with local TPUs.
......
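Note: with this change, requesting the "off" strategy returns the default single-replica strategy rather than `None`, so callers can always enter a strategy scope. A minimal sketch of the new behavior (assuming the module lives at `official.common.distribute_utils`, as in the tensorflow/models package layout):
```
# Minimal sketch; assumes the tf-models package layout (official.common.distribute_utils).
import tensorflow as tf
from official.common import distribute_utils

strategy = distribute_utils.get_distribution_strategy("off", num_gpus=0)
# "off" now yields the default (no-op, single-replica) strategy instead of None.
assert strategy is tf.distribute.get_strategy()
with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
```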
@@ -43,7 +43,7 @@ class GetDistributionStrategyTest(tf.test.TestCase):
  def test_no_strategy(self):
    ds = distribute_utils.get_distribution_strategy('off')
-   self.assertIsNone(ds)
+   self.assertIs(ds, tf.distribute.get_strategy())

  def test_invalid_strategy(self):
    with self.assertRaisesRegexp(
......
@@ -18,9 +18,27 @@ from absl import flags
def define_flags():
- """Defines flags."""
+ """Defines flags.
All flags are defined as optional, but in practice most models use some of
these flags and so mark_flags_as_required() should be called after calling
this function. Typically, 'experiment', 'mode', and 'model_dir' are required.
For example:
```
from absl import flags
from official.common import flags as tfm_flags # pylint: disable=line-too-long
...
tfm_flags.define_flags()
flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
```
The reason all flags are optional is that unit tests often do not set or
use any of the flags.
"""
  flags.DEFINE_string(
-     'experiment', default=None, help='The experiment type registered.')
+     'experiment', default=None, help=
+     'The experiment type registered, specifying an ExperimentConfig.')
  flags.DEFINE_enum(
      'mode',
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides TFM orbit actions and associated helper functions/classes."""
import os
from typing import List
import gin
import orbit
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from official.core import base_trainer
from official.core import config_definitions
from official.modeling import optimization
class PruningActions:
"""Train action to updates pruning related information.
This action updates pruning steps at the end of trainig loop, and log
pruning metrics to tensorboard.
This action must be used when training a pruned model to avoid pruning error.
"""
def __init__(
self,
export_dir: str,
model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer,
):
"""Initializes the instance.
Args:
export_dir: `str` for the export directory of the pruning summaries.
model: `tf.keras.Model` model instance used for training. This will be
used to assign a pruning step to each prunable weight.
optimizer: `tf.keras.optimizers.Optimizer` optimizer instance used for
training. This will be used to find the current training steps.
"""
self._optimizer = optimizer
self.update_pruning_step = tfmot.sparsity.keras.UpdatePruningStep()
self.update_pruning_step.set_model(model)
self.update_pruning_step.on_train_begin()
self.pruning_summaries = tfmot.sparsity.keras.PruningSummaries(
log_dir=export_dir)
model.optimizer = optimizer
self.pruning_summaries.set_model(model)
def __call__(self, output: orbit.runner.Output):
"""Update pruning step and log pruning summaries.
Args:
output: The train output to test.
"""
self.update_pruning_step.on_epoch_end(batch=None)
self.pruning_summaries.on_epoch_begin(epoch=None)
class EMACheckpointing:
"""Eval action to save checkpoint with average weights when EMA is used.
This action swaps the model weights with the average weights and then saves a
checkpoint under export_dir/ema_checkpoints. Checkpointing is expensive for
large models, so performing this action during evaluation is more efficient
than doing it during training.
"""
def __init__(self, export_dir: str, optimizer: tf.keras.optimizers.Optimizer,
checkpoint: tf.train.Checkpoint, max_to_keep: int = 1):
"""Initializes the instance.
Args:
export_dir: `str` for the export directory of the EMA average weights.
optimizer: `tf.keras.optimizers.Optimizer` optimizer instance used for
training. This will be used to swap the model weights with the average
weights.
checkpoint: `tf.train.Checkpoint` instance.
max_to_keep: `int` for max checkpoints to keep in ema_checkpoints subdir.
"""
if not isinstance(optimizer, optimization.ExponentialMovingAverage):
raise ValueError('Optimizer has to be an instance of '
'optimization.ExponentialMovingAverage for '
'the EMACheckpointing action.')
export_dir = os.path.join(export_dir, 'ema_checkpoints')
tf.io.gfile.makedirs(
os.path.dirname(export_dir))
self._optimizer = optimizer
self._checkpoint = checkpoint
self._checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
directory=export_dir,
max_to_keep=max_to_keep,
checkpoint_name='average_weights')
def __call__(self, output: orbit.runner.Output):
"""Swaps model weights, and saves the checkpoint.
Args:
output: The train or eval output to test.
"""
self._optimizer.swap_weights()
self._checkpoint_manager.save(checkpoint_number=self._optimizer.iterations)
self._optimizer.swap_weights()
@gin.configurable
def get_eval_actions(
params: config_definitions.ExperimentConfig,
trainer: base_trainer.Trainer,
model_dir: str) -> List[orbit.Action]:
"""Gets eval actions for TFM trainer."""
eval_actions = []
# Adds ema checkpointing action to save the average weights under
# ema_checkpoints subdir.
if isinstance(trainer.optimizer, optimization.ExponentialMovingAverage):
eval_actions.append(
EMACheckpointing(
export_dir=model_dir,
optimizer=trainer.optimizer,
checkpoint=trainer.checkpoint,
max_to_keep=params.trainer.max_to_keep))
return eval_actions
@gin.configurable
def get_train_actions(params: config_definitions.ExperimentConfig,
trainer: base_trainer.Trainer,
model_dir: str) -> List[orbit.Action]:
"""Gets train actions for TFM trainer."""
train_actions = []
# Adds pruning callback actions.
if hasattr(params.task, 'pruning'):
train_actions.append(
PruningActions(
export_dir=model_dir,
model=trainer.model,
optimizer=trainer.optimizer))
return train_actions
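For context, an Orbit action is simply a callable invoked with the train or eval output; `get_train_actions` and `get_eval_actions` above return lists of such callables. A toy, self-contained sketch (the action class and the printed key are illustrative, not part of this commit):
```
# Illustrative only: a custom action is any callable that takes the loop output.
class PrintLossAction:

  def __call__(self, output):
    # `output` is the metrics dict produced by the trainer or evaluator loop.
    print('training_loss:', output.get('training_loss'))

# Such an action would be wired in via
#   orbit.Controller(..., train_actions=[PrintLossAction()]),
# exactly as get_train_actions()/get_eval_actions() are in train_lib.py below.
```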
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for TFM actions."""
import os
from absl.testing import parameterized
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.core import actions
from official.modeling import optimization
class TestModel(tf.Module):
def __init__(self):
self.value = tf.Variable(0)
@tf.function(input_signature=[])
def __call__(self):
return self.value
def all_strategy_combinations():
return combinations.combine(
distribution=[
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],)
class ActionsTest(tf.test.TestCase, parameterized.TestCase):
@combinations.generate(all_strategy_combinations())
def test_ema_checkpointing(self, distribution):
with distribution.scope():
directory = self.create_tempdir()
model = TestModel()
optimizer = tf.keras.optimizers.SGD()
optimizer = optimization.ExponentialMovingAverage(
optimizer, trainable_weights_only=False)
# Creates average weights for the model variables. Average weights are
# initialized to zero.
optimizer.shadow_copy(model)
checkpoint = tf.train.Checkpoint(model=model)
# Changes model.value to 3, average value is still 0.
model.value.assign(3)
# Checks model.value is 3
self.assertEqual(model(), 3)
ema_action = actions.EMACheckpointing(directory, optimizer, checkpoint)
ema_action({})
self.assertNotEmpty(
tf.io.gfile.glob(os.path.join(directory, 'ema_checkpoints')))
checkpoint.read(tf.train.latest_checkpoint(
os.path.join(directory, 'ema_checkpoints')))
# Checks model.value is 0 after swapping.
self.assertEqual(model(), 0)
if __name__ == '__main__':
tf.test.main()
@@ -79,7 +79,7 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
  optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
  # Configuring optimizer when loss_scale is set in runtime config. This helps
  # avoiding overflow/underflow for float16 computations.
- if runtime_config and runtime_config.loss_scale:
+ if runtime_config:
    optimizer = performance.configure_optimizer(
        optimizer,
        use_float16=runtime_config.mixed_precision_dtype == "float16",
......
@@ -303,13 +303,16 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
        },
    })))
    trainer = self.create_test_trainer(config)
-   if mixed_precision_dtype != 'float16':
-     self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)
-   elif mixed_precision_dtype == 'float16' and loss_scale is None:
-     self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)
-   else:
-     self.assertIsInstance(trainer.optimizer,
-                           tf.keras.mixed_precision.LossScaleOptimizer)
+   if mixed_precision_dtype == 'float16':
+     self.assertIsInstance(trainer.optimizer,
+                           tf.keras.mixed_precision.LossScaleOptimizer)
+     if loss_scale in (None, 'dynamic'):
+       self.assertTrue(trainer.optimizer.dynamic)
+     else:
+       self.assertFalse(trainer.optimizer.dynamic)
+       self.assertEqual(trainer.optimizer.initial_scale, loss_scale)
+   else:
+     self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)
    metrics = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
    self.assertIn('training_loss', metrics)
......
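The branches above mirror how `tf.keras.mixed_precision.LossScaleOptimizer` behaves; a small self-contained sketch (the fixed scale of 128 is illustrative):
```
import tensorflow as tf

# Dynamic loss scaling: used when loss_scale is None or 'dynamic'.
dynamic_opt = tf.keras.mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.SGD())
assert dynamic_opt.dynamic

# Fixed loss scaling: used when a numeric loss_scale (e.g. 128) is configured.
fixed_opt = tf.keras.mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.SGD(), dynamic=False, initial_scale=128)
assert not fixed_opt.dynamic
assert fixed_opt.initial_scale == 128
```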
@@ -29,12 +29,13 @@ class DataConfig(base_config.Config):
  """The base configuration for building datasets.
  Attributes:
-   input_path: The path to the input. It can be either (1) a str indicating
-     a file path/pattern, or (2) a str indicating multiple file paths/patterns
-     separated by comma (e.g "a, b, c" or no spaces "a,b,c"), or
-     (3) a list of str, each of which is a file path/pattern or multiple file
-     paths/patterns separated by comma.
-     It should not be specified when the following `tfds_name` is specified.
+   input_path: The path to the input. It can be either (1) a str indicating a
+     file path/pattern, or (2) a str indicating multiple file paths/patterns
+     separated by comma (e.g "a, b, c" or no spaces "a,b,c"), or (3) a list of
+     str, each of which is a file path/pattern or multiple file paths/patterns
+     separated by comma, or (4) a dictionary of the previous three approaches
+     for more advanced data mixing using named access. It should not be
+     specified when the following `tfds_name` is specified.
    tfds_name: The name of the tensorflow dataset (TFDS). It should not be
      specified when the above `input_path` is specified.
    tfds_split: A str indicating which split of the data to load from TFDS. It
@@ -46,8 +47,8 @@ class DataConfig(base_config.Config):
    shuffle_buffer_size: The buffer size used for shuffling training data.
    cache: Whether to cache dataset examples. If `True`, we will cache the
      dataset after applying the decode_fn and parse_fn. It can be used to avoid
      re-reading from disk, re-decoding and re-parsing the example on the second
      epoch, but it requires significant memory overhead.
    cycle_length: The number of files that will be processed concurrently when
      interleaving files.
    block_length: The number of consecutive elements to produce from each input
@@ -59,11 +60,10 @@ class DataConfig(base_config.Config):
    tf_data_service_address: The URI of a tf.data service to offload
      preprocessing onto during training. The URI should be in the format
      "protocol://address", e.g. "grpc://tf-data-service:5050". It can be
      overridden by `FLAGS.tf_data_service` flag in the binary.
    tf_data_service_job_name: The name of the tf.data service job. This argument
      makes it possible for multiple datasets to share the same job. The default
      behavior is that the dataset creates anonymous, exclusively owned jobs.
    tfds_data_dir: A str specifying the directory to read/write TFDS data.
    tfds_as_supervised: A bool. When loading dataset from TFDS, if True, the
      returned tf.data.Dataset will have a 2-tuple structure (input, label)
@@ -75,7 +75,7 @@ class DataConfig(base_config.Config):
      performance.
    seed: An optional seed to use for deterministic shuffling/preprocessing.
  """
- input_path: Union[Sequence[str], str] = ""
+ input_path: Union[Sequence[str], str, base_config.Config] = ""
  tfds_name: str = ""
  tfds_split: str = ""
  global_batch_size: int = 0
@@ -239,9 +239,10 @@ class TrainerConfig(base_config.Config):
@dataclasses.dataclass
class TaskConfig(base_config.Config):
  init_checkpoint: str = ""
- model: base_config.Config = None
+ model: Optional[base_config.Config] = None
  train_data: DataConfig = DataConfig()
  validation_data: DataConfig = DataConfig()
+ name: Optional[str] = None

@dataclasses.dataclass
......
@@ -82,7 +82,7 @@ def export(export_module: ExportModule,
    The savedmodel directory path.
  """
  ckpt_dir_or_file = checkpoint_path
- if tf.io.gfile.isdir(ckpt_dir_or_file):
+ if ckpt_dir_or_file is not None and tf.io.gfile.isdir(ckpt_dir_or_file):
    ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
  if ckpt_dir_or_file:
    checkpoint = tf.train.Checkpoint(model=export_module.model)
......
@@ -14,7 +14,7 @@
"""A common dataset reader."""
import random
- from typing import Any, Callable, List, Optional
+ from typing import Any, Callable, List, Optional, Union, Dict, Sequence

from absl import logging
import tensorflow as tf
@@ -45,6 +45,7 @@ class InputReader:
      params: cfg.DataConfig,
      dataset_fn=tf.data.TFRecordDataset,
      decoder_fn: Optional[Callable[..., Any]] = None,
+     combine_fn: Optional[Callable[..., Any]] = None,
      sample_fn: Optional[Callable[..., Any]] = None,
      parser_fn: Optional[Callable[..., Any]] = None,
      transform_and_batch_fn: Optional[Callable[
@@ -59,6 +60,9 @@ class InputReader:
        example, it can be `tf.data.TFRecordDataset`.
      decoder_fn: An optional `callable` that takes the serialized data string
        and decodes them into the raw tensor dictionary.
+     combine_fn: An optional `callable` that takes a dictionary of
+       `tf.data.Dataset` objects as input and outputs a combined dataset. It
+       will be executed after the decoder_fn and before the sample_fn.
      sample_fn: An optional `callable` that takes a `tf.data.Dataset` object as
        input and outputs the transformed dataset. It performs sampling on the
        decoded raw tensors dict before the parser_fn.
@@ -78,10 +82,23 @@ class InputReader:
      raise ValueError('At most one of `input_path` and `tfds_name` can be '
                       'specified, but got %s and %s.' %
                       (params.input_path, params.tfds_name))
+   if isinstance(params.input_path,
+                 cfg.base_config.Config) and combine_fn is None:
+     raise ValueError(
+         'A `combine_fn` is required if the `input_path` is a dictionary.')
    self._tfds_builder = None
-   self._matched_files = []
+   self._matched_files = None
    if params.input_path:
-     self._matched_files = self._match_files(params.input_path)
+     # we want to combine / mix datasets
+     if isinstance(params.input_path, cfg.base_config.Config):
+       self._matched_files = {}
+       for k, v in params.input_path.as_dict().items():
+         self._matched_files[k] = self._match_files(v)
+     # single dataset
+     else:
+       self._matched_files = self._match_files(params.input_path)
    else:
      # Read dataset from TFDS.
      if not params.tfds_split:
...@@ -106,18 +123,20 @@ class InputReader: ...@@ -106,18 +123,20 @@ class InputReader:
self._dataset_fn = dataset_fn self._dataset_fn = dataset_fn
self._decoder_fn = decoder_fn self._decoder_fn = decoder_fn
self._combine_fn = combine_fn
self._sample_fn = sample_fn self._sample_fn = sample_fn
self._parser_fn = parser_fn self._parser_fn = parser_fn
self._transform_and_batch_fn = transform_and_batch_fn self._transform_and_batch_fn = transform_and_batch_fn
self._postprocess_fn = postprocess_fn self._postprocess_fn = postprocess_fn
self._seed = params.seed
# When tf.data service is enabled, each data service worker should get # When tf.data service is enabled, each data service worker should get
# different random seeds. Thus, we set `seed` to None. # different random seeds. Thus, we set `seed` to None.
if params.seed is not None: # Sharding should also be disabled because tf data service handles how
self._seed = params.seed # each worker shard data with `processing_mode` in distribute method.
elif params.enable_tf_data_service: if params.enable_tf_data_service:
self._seed = _get_random_integer()
else:
self._seed = None self._seed = None
self._sharding = False
self._enable_tf_data_service = ( self._enable_tf_data_service = (
params.enable_tf_data_service and params.tf_data_service_address) params.enable_tf_data_service and params.tf_data_service_address)
...@@ -130,7 +149,7 @@ class InputReader: ...@@ -130,7 +149,7 @@ class InputReader:
self._enable_round_robin_tf_data_service = params.get( self._enable_round_robin_tf_data_service = params.get(
'enable_round_robin_tf_data_service', False) 'enable_round_robin_tf_data_service', False)
def _match_files(self, input_path: str) -> List[str]: def _match_files(self, input_path: Union[Sequence[str], str]) -> List[str]:
"""Matches files from an input_path.""" """Matches files from an input_path."""
matched_files = [] matched_files = []
# Read dataset from files. # Read dataset from files.
...@@ -181,16 +200,21 @@ class InputReader: ...@@ -181,16 +200,21 @@ class InputReader:
# If cache is enabled, `reshuffle_each_iteration` is set to False, # If cache is enabled, `reshuffle_each_iteration` is set to False,
# because we will read the same cached data in every iteration anyway. # because we will read the same cached data in every iteration anyway.
if self._is_training: if self._is_training:
# We need a seed to shuffle the files so that when each TPU workers gets
# its own shard the files do not overlap.
if self._sharding and self._seed is None:
seed = _get_random_integer()
else:
seed = self._seed
dataset = dataset.shuffle( dataset = dataset.shuffle(
len(matched_files), len(matched_files),
seed=self._seed, seed=seed,
reshuffle_each_iteration=True if not self._cache else False) reshuffle_each_iteration=True if not self._cache else False)
# Do not enable sharding if tf.data service is enabled, as sharding will be # Do not enable sharding if tf.data service is enabled, as sharding will be
# handled inside tf.data service. # handled inside tf.data service.
if self._sharding and input_context and ( if self._sharding and input_context and (input_context.num_input_pipelines >
input_context.num_input_pipelines > 1 and 1):
not self._enable_tf_data_service):
dataset = dataset.shard(input_context.num_input_pipelines, dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id) input_context.input_pipeline_id)
...@@ -225,9 +249,8 @@ class InputReader: ...@@ -225,9 +249,8 @@ class InputReader:
dataset = dataset.with_options(options) dataset = dataset.with_options(options)
# Do not enable sharding if tf.data service is enabled, as sharding will be # Do not enable sharding if tf.data service is enabled, as sharding will be
# handled inside tf.data service. # handled inside tf.data service.
if self._sharding and input_context and ( if self._sharding and input_context and (input_context.num_input_pipelines >
input_context.num_input_pipelines > 1 and 1):
not self._enable_tf_data_service):
dataset = dataset.shard(input_context.num_input_pipelines, dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id) input_context.input_pipeline_id)
...@@ -276,42 +299,53 @@ class InputReader: ...@@ -276,42 +299,53 @@ class InputReader:
def _read_decode_and_parse_dataset( def _read_decode_and_parse_dataset(
self, self,
matched_files: List[str], matched_files: Union[Dict[str, List[str]], List[str]],
dataset_fn, dataset_fn,
batch_size: int, batch_size: int,
input_context: Optional[tf.distribute.InputContext] = None, input_context: Optional[tf.distribute.InputContext] = None,
tfds_builder: bool = False) -> tf.data.Dataset: tfds_builder: bool = False) -> tf.data.Dataset:
"""Returns a tf.data.Dataset object after reading, decoding, and parsing.""" """Returns a tf.data.Dataset object after reading, decoding, and parsing."""
def _files_to_dataset(files: List[str]) -> tf.data.Dataset:
if len(files) > 1:
if input_context and (len(files) < input_context.num_input_pipelines):
logging.warn(
'The number of files %d is less than the number of input pipelines '
'%d. We will send all input files to every worker. '
'Please consider sharding your data into more files.', len(files),
input_context.num_input_pipelines)
return self._read_files_then_shard(files, dataset_fn, input_context)
else:
return self._shard_files_then_read(files, dataset_fn, input_context)
elif len(files) == 1:
return self._read_files_then_shard(files, dataset_fn, input_context)
else:
raise ValueError('It is unexpected that `tfds_builder` is None and '
'there is also no `files`.')
def _shuffle_and_decode(ds):
# If cache is enabled, we will call `shuffle()` later after `cache()`.
if self._is_training and not self._cache:
ds = ds.shuffle(self._shuffle_buffer_size, seed=self._seed)
# Decode
ds = _maybe_map_fn(ds, self._decoder_fn)
return ds
if tfds_builder: if tfds_builder:
dataset = self._read_tfds(input_context) dataset = self._read_tfds(input_context)
elif len(matched_files) > 1: dataset = _shuffle_and_decode(dataset)
if input_context and (len(matched_files) < elif isinstance(matched_files, (list, tuple)):
input_context.num_input_pipelines): dataset = _files_to_dataset(matched_files)
logging.warn( dataset = _shuffle_and_decode(dataset)
'The number of files %d is less than the number of input pipelines ' elif isinstance(matched_files, dict):
'%d. We will send all input files to every worker. ' datasets = {}
'Please consider sharding your data into more files.', for k, fs in matched_files.items():
len(matched_files), input_context.num_input_pipelines) datasets[k] = _files_to_dataset(fs)
dataset = self._read_files_then_shard(matched_files, datasets[k] = _shuffle_and_decode(datasets[k])
dataset_fn, dataset = self._combine_fn(datasets)
input_context)
else:
dataset = self._shard_files_then_read(matched_files,
dataset_fn,
input_context)
elif len(matched_files) == 1:
dataset = self._read_files_then_shard(matched_files,
dataset_fn,
input_context)
else: else:
raise ValueError('It is unexpected that `tfds_builder` is None and ' raise ValueError('`matched_files` should be a list or dict.')
'there is also no `matched_files`.')
# If cache is enabled, we will call `shuffle()` later after `cache()`.
if self._is_training and not self._cache:
dataset = dataset.shuffle(self._shuffle_buffer_size, seed=self._seed)
dataset = _maybe_map_fn(dataset, self._decoder_fn)
if self._sample_fn is not None: if self._sample_fn is not None:
dataset = dataset.apply(self._sample_fn) dataset = dataset.apply(self._sample_fn)
dataset = _maybe_map_fn(dataset, self._parser_fn) dataset = _maybe_map_fn(dataset, self._parser_fn)
...@@ -328,8 +362,7 @@ class InputReader: ...@@ -328,8 +362,7 @@ class InputReader:
per_replica_batch_size = input_context.get_per_replica_batch_size( per_replica_batch_size = input_context.get_per_replica_batch_size(
batch_size) if input_context else batch_size batch_size) if input_context else batch_size
dataset = dataset.batch( dataset = dataset.batch(
per_replica_batch_size, drop_remainder=self._drop_remainder per_replica_batch_size, drop_remainder=self._drop_remainder)
)
return dataset return dataset
......
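The new `combine_fn` receives the dict of per-key datasets built for a dictionary-valued `input_path`. One possible implementation (a sketch only; the sampling-based mixing and the equal weights are assumptions, not part of this commit):
```
import tensorflow as tf

def mixing_combine_fn(datasets):
  # `datasets` is e.g. {'a': ds_a, 'b': ds_b}, as built by InputReader for a
  # dict-valued input_path. The equal weights below are purely illustrative.
  keys = sorted(datasets)
  return tf.data.experimental.sample_from_datasets(
      [datasets[k] for k in keys], weights=[1.0 / len(keys)] * len(keys))

# Smoke test with synthetic datasets:
mixed = mixing_combine_fn({
    'a': tf.data.Dataset.from_tensor_slices([1, 2, 3]),
    'b': tf.data.Dataset.from_tensor_slices([10, 20, 30]),
})
```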
@@ -15,13 +15,15 @@
"""TFM common training driver library."""
# pytype: disable=attribute-error
import os
- from typing import Any, Mapping, Tuple, Optional
+ from typing import Any, Mapping, Optional, Tuple

# Import libraries
from absl import logging
import orbit
import tensorflow as tf

+ from official.core import actions
from official.core import base_task
from official.core import base_trainer
from official.core import config_definitions
@@ -38,7 +40,8 @@ def run_experiment(
    model_dir: str,
    run_post_eval: bool = False,
    save_summary: bool = True,
-   trainer: Optional[base_trainer.Trainer] = None
+   trainer: Optional[base_trainer.Trainer] = None,
+   controller_cls=orbit.Controller
) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
  """Runs train/eval configured by the experiment params.
@@ -54,6 +57,8 @@ def run_experiment(
    save_summary: Whether to save train and validation summary.
    trainer: the base_trainer.Trainer instance. It should be created within the
      strategy.scope().
+   controller_cls: The controller class to manage the train and eval process.
+     Must be an orbit.Controller subclass.
  Returns:
    A 2-tuple of (model, eval_logs).
@@ -73,6 +78,8 @@ def run_experiment(
        params, model_dir))
  if trainer.checkpoint:
+   if model_dir is None:
+     raise ValueError('model_dir must be specified, but got None')
    checkpoint_manager = tf.train.CheckpointManager(
        trainer.checkpoint,
        directory=model_dir,
@@ -85,7 +92,7 @@ def run_experiment(
  else:
    checkpoint_manager = None
- controller = orbit.Controller(
+ controller = controller_cls(
      strategy=distribution_strategy,
      trainer=trainer if 'train' in mode else None,
      evaluator=trainer,
@@ -97,7 +104,9 @@ def run_experiment(
          params.trainer.validation_summary_subdir) if
      (save_summary) else None,
      summary_interval=params.trainer.summary_interval if
-     (save_summary) else None)
+     (save_summary) else None,
+     train_actions=actions.get_train_actions(params, trainer, model_dir),
+     eval_actions=actions.get_eval_actions(params, trainer, model_dir))
  logging.info('Starts to execute mode: %s', mode)
  with distribution_strategy.scope():
@@ -129,6 +138,11 @@ def run_experiment(
  logging.info('Number of trainable params in model: %f Millions.',
               num_params / 10.**6)
+ flops = train_utils.try_count_flops(trainer.model)
+ if flops is not None:
+   logging.info('FLOPs (multi-adds) in model: %f Billions.',
+                flops / 10.**9 / 2)
  if run_post_eval:
    with distribution_strategy.scope():
      return trainer.model, trainer.evaluate(
......
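The new `controller_cls` argument lets callers substitute any `orbit.Controller` subclass for the default controller. A minimal, runnable sketch of such a subclass (hypothetical; the surrounding `run_experiment` call and its experiment config are elided):
```
import orbit

class LoggingController(orbit.Controller):
  """Hypothetical subclass; any orbit.Controller subclass is accepted."""

# A caller would then pass it as:
#   train_lib.run_experiment(..., controller_cls=LoggingController)
```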
...@@ -17,7 +17,7 @@ import copy ...@@ -17,7 +17,7 @@ import copy
import json import json
import os import os
import pprint import pprint
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional, Union
from absl import logging from absl import logging
import dataclasses import dataclasses
...@@ -25,6 +25,9 @@ import gin ...@@ -25,6 +25,9 @@ import gin
import orbit import orbit
import tensorflow as tf import tensorflow as tf
# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2_as_graph
# pylint: enable=g-direct-tensorflow-import
from official.core import base_task from official.core import base_task
from official.core import base_trainer from official.core import base_trainer
from official.core import config_definitions from official.core import config_definitions
...@@ -139,14 +142,19 @@ class BestCheckpointExporter: ...@@ -139,14 +142,19 @@ class BestCheckpointExporter:
return self._checkpoint_manager return self._checkpoint_manager
def maybe_export_checkpoint(self, checkpoint, eval_logs, global_step): def maybe_export_checkpoint(
self, checkpoint, eval_logs, global_step, write_logs=True) -> bool:
"""Compare eval_logs with past eval_logs and export checkpoint if better."""
logging.info('[BestCheckpointExporter] received eval_logs: %s, at step: %d', logging.info('[BestCheckpointExporter] received eval_logs: %s, at step: %d',
eval_logs, global_step) eval_logs, global_step)
if self._best_ckpt_logs is None or self._new_metric_is_better( if self._best_ckpt_logs is None or self._new_metric_is_better(
self._best_ckpt_logs, eval_logs): self._best_ckpt_logs, eval_logs):
self._best_ckpt_logs = eval_logs self._best_ckpt_logs = eval_logs
self._export_best_eval_metric(checkpoint, self._best_ckpt_logs, if write_logs:
global_step) self.export_best_eval_metric(self._best_ckpt_logs, global_step)
self._get_checkpoint_manager(checkpoint).save()
return True
return False
def _maybe_load_best_eval_metric(self): def _maybe_load_best_eval_metric(self):
if not tf.io.gfile.exists(self.best_ckpt_logs_path): if not tf.io.gfile.exists(self.best_ckpt_logs_path):
...@@ -177,7 +185,7 @@ class BestCheckpointExporter: ...@@ -177,7 +185,7 @@ class BestCheckpointExporter:
return True return True
return False return False
def _export_best_eval_metric(self, checkpoint, eval_logs, global_step): def export_best_eval_metric(self, eval_logs, global_step):
"""Export evaluation results of the best checkpoint into a json file.""" """Export evaluation results of the best checkpoint into a json file."""
eval_logs_ext = copy.copy(eval_logs) eval_logs_ext = copy.copy(eval_logs)
eval_logs_ext['best_ckpt_global_step'] = global_step eval_logs_ext['best_ckpt_global_step'] = global_step
...@@ -187,8 +195,6 @@ class BestCheckpointExporter: ...@@ -187,8 +195,6 @@ class BestCheckpointExporter:
with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'w') as writer: with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'w') as writer:
writer.write(json.dumps(eval_logs_ext, indent=4) + '\n') writer.write(json.dumps(eval_logs_ext, indent=4) + '\n')
self._get_checkpoint_manager(checkpoint).save()
@property @property
def best_ckpt_logs(self): def best_ckpt_logs(self):
return self._best_ckpt_logs return self._best_ckpt_logs
...@@ -241,6 +247,9 @@ class ParseConfigOptions: ...@@ -241,6 +247,9 @@ class ParseConfigOptions:
def parse_configuration(flags_obj, lock_return=True, print_return=True): def parse_configuration(flags_obj, lock_return=True, print_return=True):
"""Parses ExperimentConfig from flags.""" """Parses ExperimentConfig from flags."""
if flags_obj.experiment is None:
raise ValueError('The flag --experiment must be specified.')
# 1. Get the default config from the registered experiment. # 1. Get the default config from the registered experiment.
params = exp_factory.get_exp_config(flags_obj.experiment) params = exp_factory.get_exp_config(flags_obj.experiment)
...@@ -285,7 +294,7 @@ def parse_configuration(flags_obj, lock_return=True, print_return=True): ...@@ -285,7 +294,7 @@ def parse_configuration(flags_obj, lock_return=True, print_return=True):
if print_return: if print_return:
pp = pprint.PrettyPrinter() pp = pprint.PrettyPrinter()
logging.info('Final experiment parameters: %s', logging.info('Final experiment parameters:\n%s',
pp.pformat(params.as_dict())) pp.pformat(params.as_dict()))
return params return params
...@@ -294,6 +303,8 @@ def parse_configuration(flags_obj, lock_return=True, print_return=True): ...@@ -294,6 +303,8 @@ def parse_configuration(flags_obj, lock_return=True, print_return=True):
def serialize_config(params: config_definitions.ExperimentConfig, def serialize_config(params: config_definitions.ExperimentConfig,
model_dir: str): model_dir: str):
"""Serializes and saves the experiment config.""" """Serializes and saves the experiment config."""
if model_dir is None:
raise ValueError('model_dir must be specified, but got None')
params_save_path = os.path.join(model_dir, 'params.yaml') params_save_path = os.path.join(model_dir, 'params.yaml')
logging.info('Saving experiment configuration to %s', params_save_path) logging.info('Saving experiment configuration to %s', params_save_path)
tf.io.gfile.makedirs(model_dir) tf.io.gfile.makedirs(model_dir)
@@ -369,11 +380,15 @@ def remove_ckpts(model_dir):
    tf.io.gfile.remove(file_to_remove)
- def try_count_params(model: tf.keras.Model):
+ def try_count_params(
+     model: Union[tf.Module, tf.keras.Model],
+     trainable_only: bool = False):
  """Count the number of parameters if model is possible.

  Args:
    model: Try to count the number of params in this model.
+   trainable_only: Whether to calculate trainable params only. This flag is
+     not used when the model has `count_params` attribute.

  Returns:
    The number of parameters or None.
@@ -387,4 +402,55 @@ def try_count_params(model: tf.keras.Model):
        'because the model was not fed any input, e.g., the max '
        'train step already reached before this run.')
    return None
else:
total_params = 0
variables = model.trainable_variables if trainable_only else model.variables
for var in variables:
shape = tf.shape(var)
total_params += tf.math.reduce_prod(shape).numpy()
return total_params
def try_count_flops(model: Union[tf.Module, tf.keras.Model],
inputs_kwargs: Optional[Dict[str, Any]] = None):
"""Counts and returns model FLOPs.
Args:
model: A model instance.
inputs_kwargs: An optional dictionary of argument pairs specifying inputs'
shape specifications used to get the corresponding concrete function.
Returns:
The model's FLOPs.
"""
if hasattr(model, 'inputs'):
try:
# Get input shape and set batch size to 1.
if model.inputs:
inputs = [
tf.TensorSpec([1] + input.shape[1:], input.dtype)
for input in model.inputs
]
concrete_func = tf.function(model).get_concrete_function(inputs)
# If model.inputs is invalid, try to use the input to get concrete
# function for model.call (subclass model).
else:
concrete_func = tf.function(model.call).get_concrete_function(
**inputs_kwargs)
frozen_func, _ = convert_variables_to_constants_v2_as_graph(concrete_func)
# Calculate FLOPs.
run_meta = tf.compat.v1.RunMetadata()
opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
opts['output'] = 'none'
flops = tf.compat.v1.profiler.profile(
graph=frozen_func.graph, run_meta=run_meta, options=opts)
return flops.total_float_ops
except Exception as e: # pylint: disable=broad-except
logging.info(
'Failed to count model FLOPs with error %s, because the build() '
'methods in keras layers were not called. This is probably because '
'the model was not fed any input, e.g., the max train step already '
'reached before this run.', e)
return None
return None return None
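A quick usage sketch for the new FLOPs helper (assuming the tf-models package layout, i.e. `official.core.train_utils`):
```
import tensorflow as tf
from official.core import train_utils  # assumption: tf-models package layout

model = tf.keras.Sequential([tf.keras.layers.Dense(8, input_shape=(32,))])
flops = train_utils.try_count_flops(model)
if flops is not None:
  # try_count_flops returns total float ops; train_lib logs flops / 2 as
  # multiply-adds.
  print('float ops:', flops, 'multiply-adds:', flops / 2)
```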
@@ -23,6 +23,7 @@ from official.modeling import hyperparams
@dataclasses.dataclass
class TaskRoutine(hyperparams.Config):
+ # TODO(hongkuny): deprecate the task_name once we migrated client code.
  task_name: str = ""
  task_config: cfg.TaskConfig = None
  eval_steps: Optional[int] = None
@@ -76,4 +77,4 @@ class MultiEvalExperimentConfig(cfg.ExperimentConfig):
  Attributes:
    eval_tasks: individual evaluation tasks.
  """
- eval_tasks: MultiTaskConfig = MultiTaskConfig()
+ eval_tasks: Tuple[TaskRoutine, ...] = ()
...@@ -16,14 +16,14 @@ ...@@ -16,14 +16,14 @@
The evaluator implements the Orbit `AbstractEvaluator` interface. The evaluator implements the Orbit `AbstractEvaluator` interface.
""" """
from typing import Optional, Union from typing import Dict, List, Optional, Union
import gin import gin
import orbit import orbit
import tensorflow as tf import tensorflow as tf
from official.core import base_task
from official.core import train_utils from official.core import train_utils
from official.modeling.multitask import base_model from official.modeling.multitask import base_model
from official.modeling.multitask import multitask
@gin.configurable @gin.configurable
...@@ -32,37 +32,39 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator): ...@@ -32,37 +32,39 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
def __init__( def __init__(
self, self,
task: multitask.MultiTask, eval_tasks: List[base_task.Task],
model: Union[tf.keras.Model, base_model.MultiTaskBaseModel], model: Union[tf.keras.Model, base_model.MultiTaskBaseModel],
global_step: Optional[tf.Variable] = None, global_step: Optional[tf.Variable] = None,
eval_steps: Optional[Dict[str, int]] = None,
checkpoint_exporter: Optional[train_utils.BestCheckpointExporter] = None): checkpoint_exporter: Optional[train_utils.BestCheckpointExporter] = None):
"""Initialize common trainer for TensorFlow models. """Initialize common trainer for TensorFlow models.
Args: Args:
task: A multitask.MultiTask instance. eval_tasks: A list of tasks to evaluate.
model: tf.keras.Model instance. model: tf.keras.Model instance.
global_step: the global step variable. global_step: the global step variable.
eval_steps: a dictionary of steps to run eval keyed by task names.
checkpoint_exporter: an object that has the `maybe_export_checkpoint` checkpoint_exporter: an object that has the `maybe_export_checkpoint`
interface. interface.
""" """
# Gets the current distribution strategy. If not inside any strategy scope, # Gets the current distribution strategy. If not inside any strategy scope,
# it gets a single-replica no-op strategy. # it gets a single-replica no-op strategy.
self._strategy = tf.distribute.get_strategy() self._strategy = tf.distribute.get_strategy()
self._task = task self._tasks = eval_tasks
self._model = model self._model = model
self._global_step = global_step or orbit.utils.create_global_step() self._global_step = global_step or orbit.utils.create_global_step()
self._checkpoint_exporter = checkpoint_exporter self._checkpoint_exporter = checkpoint_exporter
self._checkpoint = tf.train.Checkpoint( self._checkpoint = tf.train.Checkpoint(
global_step=self.global_step, global_step=self.global_step, model=self.model)
model=self.model)
self._validation_losses = None self._validation_losses = None
self._validation_metrics = None self._validation_metrics = None
# Builds per-task datasets. # Builds per-task datasets.
self.eval_datasets = {} self.eval_datasets = {}
for name, task in self.task.tasks.items(): self.eval_steps = eval_steps or {}
self.eval_datasets[name] = orbit.utils.make_distributed_dataset( for task in self.tasks:
self.eval_datasets[task.name] = orbit.utils.make_distributed_dataset(
self.strategy, task.build_inputs, task.task_config.validation_data) self.strategy, task.build_inputs, task.task_config.validation_data)
# Builds per-task validation loops. # Builds per-task validation loops.
...@@ -89,8 +91,7 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator): ...@@ -89,8 +91,7 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
return orbit.utils.create_loop_fn(eval_step_fn) return orbit.utils.create_loop_fn(eval_step_fn)
self.task_fns = { self.task_fns = {
name: get_function(name, task) task.name: get_function(task.name, task) for task in self.tasks
for name, task in self.task.tasks.items()
} }
@property @property
...@@ -98,8 +99,8 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator): ...@@ -98,8 +99,8 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
return self._strategy return self._strategy
@property @property
def task(self): def tasks(self):
return self._task return self._tasks
@property @property
def model(self): def model(self):
...@@ -115,8 +116,8 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator): ...@@ -115,8 +116,8 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
if self._validation_losses is None: if self._validation_losses is None:
# Builds the per-task metrics and losses. # Builds the per-task metrics and losses.
self._validation_losses = {} self._validation_losses = {}
for name in self.task.tasks: for task in self.tasks:
self._validation_losses[name] = tf.keras.metrics.Mean( self._validation_losses[task.name] = tf.keras.metrics.Mean(
"validation_loss", dtype=tf.float32) "validation_loss", dtype=tf.float32)
return self._validation_losses return self._validation_losses
...@@ -126,8 +127,8 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator): ...@@ -126,8 +127,8 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
if self._validation_metrics is None: if self._validation_metrics is None:
# Builds the per-task metrics and losses. # Builds the per-task metrics and losses.
self._validation_metrics = {} self._validation_metrics = {}
for name, task in self.task.tasks.items(): for task in self.tasks:
self._validation_metrics[name] = task.build_metrics(training=False) self._validation_metrics[task.name] = task.build_metrics(training=False)
return self._validation_metrics return self._validation_metrics
@property @property
...@@ -145,12 +146,12 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator): ...@@ -145,12 +146,12 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
results = {} results = {}
eval_iters = tf.nest.map_structure(iter, self.eval_datasets) eval_iters = tf.nest.map_structure(iter, self.eval_datasets)
for name, task_eval_loop in self.task_fns.items(): for task in self.tasks:
outputs = None outputs = None
name = task.name
eval_iter = eval_iters[name] eval_iter = eval_iters[name]
task = self.task.tasks[name] task_eval_steps = self.eval_steps.get(name, None) or num_steps
task_eval_steps = self.task.task_eval_steps(name) or num_steps outputs = self.task_fns[name](
outputs = task_eval_loop(
eval_iter, eval_iter,
task_eval_steps, task_eval_steps,
state=outputs, state=outputs,
......
```diff
@@ -22,7 +22,6 @@ from tensorflow.python.distribute import strategy_combinations
 from official.core import base_task
 from official.core import config_definitions as cfg
 from official.modeling.multitask import evaluator
-from official.modeling.multitask import multitask
 def all_strategy_combinations():
@@ -89,9 +88,7 @@ class MockTask(base_task.Task):
           np.concatenate([np.expand_dims(v.numpy(), axis=0) for v in value]))
     return state
-  def reduce_aggregated_logs(self,
-                             aggregated_logs,
-                             global_step=None):
+  def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
     for k, v in aggregated_logs.items():
       aggregated_logs[k] = np.sum(np.stack(v, axis=0))
     return aggregated_logs
@@ -106,10 +103,9 @@ class EvaluatorTest(tf.test.TestCase, parameterized.TestCase):
         MockTask(params=cfg.TaskConfig(), name="bar"),
         MockTask(params=cfg.TaskConfig(), name="foo")
     ]
-    test_multitask = multitask.MultiTask(tasks=tasks)
     model = MockModel()
     test_evaluator = evaluator.MultiTaskEvaluator(
-        task=test_multitask, model=model)
+        eval_tasks=tasks, model=model)
     results = test_evaluator.evaluate(tf.convert_to_tensor(1, dtype=tf.int32))
     self.assertContainsSubset(["validation_loss", "acc"], results["bar"].keys())
     self.assertContainsSubset(["validation_loss", "acc"], results["foo"].keys())
@@ -123,10 +119,9 @@ class EvaluatorTest(tf.test.TestCase, parameterized.TestCase):
         MockTask(params=cfg.TaskConfig(), name="bar"),
         MockTask(params=cfg.TaskConfig(), name="foo")
     ]
-    test_multitask = multitask.MultiTask(tasks=tasks)
     model = MockModel()
     test_evaluator = evaluator.MultiTaskEvaluator(
-        task=test_multitask, model=model)
+        eval_tasks=tasks, model=model)
     results = test_evaluator.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
     self.assertEqual(results["bar"]["counter"],
                      5. * distribution.num_replicas_in_sync)
...
```
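The test changes above boil down to one constructor difference: `MultiTaskEvaluator` now takes the task list directly instead of a `multitask.MultiTask` wrapper. A hedged before/after sketch, reusing the `MockTask`/`MockModel` pattern from the test (those test doubles are defined earlier in the same test file; any registered `base_task.Task` subclass and Keras model would play the same role):

```python
import tensorflow as tf

from official.core import config_definitions as cfg
from official.modeling.multitask import evaluator

# MockTask and MockModel are the test doubles from the surrounding test file.
tasks = [
    MockTask(params=cfg.TaskConfig(), name="bar"),
    MockTask(params=cfg.TaskConfig(), name="foo"),
]
model = MockModel()

# Old style (removed in this commit):
#   test_multitask = multitask.MultiTask(tasks=tasks)
#   evaluator.MultiTaskEvaluator(task=test_multitask, model=model)
# New style: pass the tasks directly.
test_evaluator = evaluator.MultiTaskEvaluator(eval_tasks=tasks, model=model)

# Results are keyed by task name, e.g. results["foo"]["validation_loss"].
results = test_evaluator.evaluate(tf.convert_to_tensor(1, dtype=tf.int32))
```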
```diff
@@ -59,10 +59,7 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
     else:
       raise ValueError("The tasks argument has an invalid type: %s" %
                        type(tasks))
-    self._task_eval_steps = task_eval_steps or {}
-    self._task_eval_steps = dict([
-        (name, self._task_eval_steps.get(name, None)) for name in self.tasks
-    ])
+    self.task_eval_steps = task_eval_steps or {}
     self._task_weights = task_weights or {}
     self._task_weights = dict([
         (name, self._task_weights.get(name, 1.0)) for name in self.tasks
@@ -74,9 +71,9 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
     task_eval_steps = {}
     task_weights = {}
     for task_routine in config.task_routines:
-      task_name = task_routine.task_name
+      task_name = task_routine.task_name or task_routine.task_config.name
       tasks[task_name] = task_factory.get_task(
-          task_routine.task_config, logging_dir=logging_dir)
+          task_routine.task_config, logging_dir=logging_dir, name=task_name)
       task_eval_steps[task_name] = task_routine.eval_steps
       task_weights[task_name] = task_routine.task_weight
     return cls(
@@ -86,9 +83,6 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
   def tasks(self):
     return self._tasks
-  def task_eval_steps(self, task_name):
-    return self._task_eval_steps[task_name]
   def task_weight(self, task_name):
     return self._task_weights[task_name]
...
```
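Two behavioural notes on the `MultiTask` changes above: `task_eval_steps` becomes a plain dict attribute rather than a method, and `from_config` now falls back to the task config's own `name` when a routine does not set `task_name`. A small, self-contained sketch of that fallback using hypothetical stand-in dataclasses (the real objects are `configs.TaskRoutine` and the registered task config classes):

```python
import dataclasses
from typing import Any


@dataclasses.dataclass
class FakeTaskConfig:
  """Hypothetical stand-in for a task config that carries its own name."""
  name: str = "foo"


@dataclasses.dataclass
class FakeTaskRoutine:
  """Hypothetical stand-in for configs.TaskRoutine."""
  task_config: Any
  task_name: str = ""
  eval_steps: int = 0


routine = FakeTaskRoutine(task_config=FakeTaskConfig())  # no explicit task_name
# Same expression as in from_config above: the routine name wins if set,
# otherwise the task config's name is used.
task_name = routine.task_name or routine.task_config.name
assert task_name == "foo"
```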
```diff
@@ -15,7 +15,7 @@
 """Multitask training driver library."""
 # pytype: disable=attribute-error
 import os
-from typing import Optional
+from typing import List, Optional
 from absl import logging
 import orbit
 import tensorflow as tf
@@ -69,9 +69,11 @@ def run_experiment(*, distribution_strategy: tf.distribute.Strategy,
   trainer = TRAINERS[params.trainer.trainer_type](
       **kwargs) if is_training else None
   if is_eval:
+    eval_steps = task.task_eval_steps
     evaluator = evaluator_lib.MultiTaskEvaluator(
-        task=task,
+        eval_tasks=task.tasks.values(),
         model=model,
+        eval_steps=eval_steps,
         global_step=trainer.global_step if is_training else None,
         checkpoint_exporter=train_utils.maybe_create_best_ckpt_exporter(
             params, model_dir))
@@ -137,7 +139,7 @@ def run_experiment_with_multitask_eval(
     *,
     distribution_strategy: tf.distribute.Strategy,
     train_task: base_task.Task,
-    eval_tasks: multitask.MultiTask,
+    eval_tasks: List[base_task.Task],
     mode: str,
     params: configs.MultiEvalExperimentConfig,
     model_dir: str,
@@ -149,7 +151,7 @@ def run_experiment_with_multitask_eval(
   Args:
     distribution_strategy: A distribution distribution_strategy.
     train_task: A base_task.Task instance.
-    eval_tasks: A multitask.MultiTask with evaluation tasks.
+    eval_tasks: A list of evaluation tasks.
     mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
       or 'continuous_eval'.
     params: MultiEvalExperimentConfig instance.
@@ -173,8 +175,8 @@ def run_experiment_with_multitask_eval(
         config=params,
         task=train_task,
         model=train_task.build_model(),
-        optimizer=train_task.create_optimizer(
-            params.trainer.optimizer_config, params.runtime),
+        optimizer=train_task.create_optimizer(params.trainer.optimizer_config,
+                                              params.runtime),
         train=True,
         evaluate=False)
   else:
@@ -182,10 +184,14 @@ def run_experiment_with_multitask_eval(
   model = trainer.model if trainer else train_task.build_model()
   if is_eval:
+    eval_steps = dict([(task_routine.task_config.name,
+                        task_routine.eval_steps)
+                       for task_routine in params.eval_tasks])
     evaluator = evaluator_lib.MultiTaskEvaluator(
-        task=eval_tasks,
+        eval_tasks=eval_tasks,
         model=model,
         global_step=trainer.global_step if is_training else None,
+        eval_steps=eval_steps,
         checkpoint_exporter=train_utils.maybe_create_best_ckpt_exporter(
             params, model_dir))
   else:
...
```
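In the driver above, per-task evaluation step counts are collected from the `TaskRoutine` entries and handed to the evaluator alongside the plain task list. A hedged sketch of that wiring follows; `params`, `eval_tasks`, `model`, and `trainer` are assumed to exist as in the surrounding `run_experiment_with_multitask_eval` code:

```python
from official.modeling.multitask import evaluator as evaluator_lib

# Per-task eval steps keyed by task name, taken from the experiment config's
# TaskRoutine entries (same mapping the diff builds with dict([...])).
eval_steps = {
    task_routine.task_config.name: task_routine.eval_steps
    for task_routine in params.eval_tasks
}

evaluator = evaluator_lib.MultiTaskEvaluator(
    eval_tasks=eval_tasks,  # a list of base_task.Task instances
    model=model,
    eval_steps=eval_steps,
    global_step=trainer.global_step if trainer else None,
)
```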
```diff
@@ -65,8 +65,7 @@ class TrainLibTest(tf.test.TestCase, parameterized.TestCase):
         task=configs.MultiTaskConfig(
             task_routines=(
                 configs.TaskRoutine(
-                    task_name='foo',
-                    task_config=test_utils.FooConfig()),
+                    task_name='foo', task_config=test_utils.FooConfig()),
                 configs.TaskRoutine(
                     task_name='bar', task_config=test_utils.BarConfig()))))
     experiment_config = params_dict.override_params_dict(
@@ -95,18 +94,20 @@ class TrainLibTest(tf.test.TestCase, parameterized.TestCase):
     model_dir = self.get_temp_dir()
     experiment_config = configs.MultiEvalExperimentConfig(
         task=test_utils.FooConfig(),
-        eval_tasks=configs.MultiTaskConfig(
-            task_routines=(
-                configs.TaskRoutine(
-                    task_name='foo',
-                    task_config=test_utils.FooConfig()),
-                configs.TaskRoutine(
-                    task_name='bar', task_config=test_utils.BarConfig()))))
+        eval_tasks=(configs.TaskRoutine(
+            task_name='foo', task_config=test_utils.FooConfig(), eval_steps=2),
+                    configs.TaskRoutine(
+                        task_name='bar',
+                        task_config=test_utils.BarConfig(),
+                        eval_steps=3)))
     experiment_config = params_dict.override_params_dict(
         experiment_config, self._test_config, is_strict=False)
     with distribution_strategy.scope():
       train_task = task_factory.get_task(experiment_config.task)
-      eval_tasks = multitask.MultiTask.from_config(experiment_config.eval_tasks)
+      eval_tasks = [
+          task_factory.get_task(config.task_config, name=config.task_name)
+          for config in experiment_config.eval_tasks
+      ]
       train_lib.run_experiment_with_multitask_eval(
           distribution_strategy=distribution_strategy,
           train_task=train_task,
...
```