Commit 3d61d6b3 authored by qianyj

initial files for ResNet50

parent d3a70caf
# ResNet50 (Residual Network 50)
## Model Introduction
Train ResNet50 with TensorFlow 2.
## Model Architecture
The ResNet50 network consists of 49 convolutional layers plus 1 fully connected layer, among other layers.
## Dataset
Training uses the ImageNet dataset, which must be converted to TFRecord format.
The ImageNet dataset can be downloaded from the [official site](https://image-net.org/ "ImageNet official site"), found via a Baidu search, or obtained by contacting us.
To convert ImageNet to TFRecord format, refer to this [script](https://github.com/tensorflow/tpu/blob/master/tools/datasets/imagenet_to_gcs.py) and its [README](https://github.com/tensorflow/tpu/tree/master/tools/datasets#imagenet_to_gcspy).
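To sanity-check a converted shard before training, here is a minimal sketch; the feature keys follow the imagenet_to_gcs.py convention, and the shard path is hypothetical:
```python
import tensorflow as tf

# Hypothetical path to one converted training shard; adjust to your output directory.
shard = "/path/to/imagenet_tfrecord/train-00000-of-01024"

# Feature keys as written by the imagenet_to_gcs.py conversion script.
feature_spec = {
    "image/encoded": tf.io.FixedLenFeature([], tf.string),
    "image/class/label": tf.io.FixedLenFeature([], tf.int64),
}

for record in tf.data.TFRecordDataset([shard]).take(1):
    example = tf.io.parse_single_example(record, feature_spec)
    image = tf.io.decode_jpeg(example["image/encoded"], channels=3)
    print(image.shape, int(example["image/class/label"]))
```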
## Training
### Environment Setup
Pull the training Docker image from [光源 (SourceFind)](https://www.sourcefind.cn/#/service-details):
* Training image: docker pull image.sourcefind.cn:5000/dcu/admin/base/tensorflow:2.7.0-centos7.6-dtk-22.10.1-py37-latest
Install the Python dependencies:
pip install -r requirement.txt
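An optional, quick sanity check inside the container, assuming the DCUs are exposed to TensorFlow as GPU devices:
```python
import tensorflow as tf

print("TensorFlow version:", tf.__version__)  # expected 2.7.x in this image
print("Visible accelerators:", tf.config.list_physical_devices("GPU"))
```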
### FP32 Training
#### Single-node, single-card training command:
Without XLA:
export PYTHONPATH=/path/to/ResNet50_TensorFlow2:$PYTHONPATH
python3 official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py --data_dir=/path/to/{ImageNet-tensorflow_data_dir} --model_dir=/path/to/{model_save_dir} --batch_size=128 --num_gpus=1 --use_synthetic_data=false --dtype=fp32
With XLA:
export PYTHONPATH=/path/to/ResNet50_TensorFlow2:$PYTHONPATH
TF_XLA_FLAGS="--tf_xla_auto_jit=2" python3 official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py --data_dir=/path/to/{ImageNet-tensorflow_data_dir} --model_dir=/path/to/{model_save_dir} --batch_size=128 --num_gpus=1 --use_synthetic_data=false --dtype=fp32
#### Single-node, four-card training command:
Without XLA:
export PYTHONPATH=/path/to/ResNet50_TensorFlow2:$PYTHONPATH
python3 official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py --data_dir=/path/to/{ImageNet-tensorflow_data_dir} --model_dir=/path/to/{model_save_dir} --batch_size=512 --num_gpus=4 --use_synthetic_data=false --dtype=fp32
With XLA:
export PYTHONPATH=/path/to/ResNet50_TensorFlow2:$PYTHONPATH
TF_XLA_FLAGS="--tf_xla_auto_jit=2" python3 official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py --data_dir=/path/to/{ImageNet-tensorflow_data_dir} --model_dir=/path/to/{model_save_dir} --batch_size=512 --num_gpus=4 --use_synthetic_data=false --dtype=fp32
#### Multi-node, multi-card training command (example: a single node with four cards simulating four cards across four processes):
Run the sed command once; it inserts the code that enables multi-card (multi-worker) training (the contents of configfile appear further below):
sed -i '100 r configfile' models-master/official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py
Without XLA:
export PYTHONPATH=/path/to/ResNet50_TensorFlow2:$PYTHONPATH
mpirun -np 4 --hostfile hostfile -mca btl self,tcp --allow-run-as-root --bind-to none scripts-run/single_process.sh
With XLA:
export PYTHONPATH=/path/to/ResNet50_TensorFlow2:$PYTHONPATH
mpirun -np 4 --hostfile hostfile -mca btl self,tcp --allow-run-as-root --bind-to none scripts-run/single_process_xla.sh
### FP16 Training
#### Single-node, single-card training command
Without XLA:
export PYTHONPATH=/path/to/ResNet50_TensorFlow2:$PYTHONPATH
python3 official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py --data_dir=/path/to/{ImageNet-tensorflow_data_dir} --model_dir=/path/to/{model_save_dir} --batch_size=128 --num_gpus=1 --use_synthetic_data=false --dtype=fp16
With XLA:
export PYTHONPATH=/path/to/ResNet50_TensorFlow2:$PYTHONPATH
TF_XLA_FLAGS="--tf_xla_auto_jit=2" python3 official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py --data_dir=/path/to/{ImageNet-tensorflow_data_dir} --model_dir=/path/to/{model_save_dir} --batch_size=128 --num_gpus=1 --use_synthetic_data=false --dtype=fp16
#### Single-node, four-card training command
Without XLA:
export PYTHONPATH=/path/to/ResNet50_TensorFlow2:$PYTHONPATH
python3 official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py --data_dir=/path/to/{ImageNet-tensorflow_data_dir} --model_dir=/path/to/{model_save_dir} --batch_size=512 --num_gpus=4 --use_synthetic_data=false --dtype=fp16
With XLA:
export PYTHONPATH=/path/to/ResNet50_TensorFlow2:$PYTHONPATH
TF_XLA_FLAGS="--tf_xla_auto_jit=2" python3 official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py --data_dir=/path/to/{ImageNet-tensorflow_data_dir} --model_dir=/path/to/{model_save_dir} --batch_size=512 --num_gpus=4 --use_synthetic_data=false --dtype=fp16
#### Multi-node, multi-card training command (example: a single node with four cards simulating four cards across four processes)
Run the sed command once; it inserts the code that enables multi-card (multi-worker) training:
sed -i '100 r configfile' models-master/official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py
Set --dtype=fp16 in scripts-run/single_process.sh and scripts-run/single_process_xla.sh.
Without XLA:
export PYTHONPATH=/path/to/ResNet50_TensorFlow2:$PYTHONPATH
mpirun -np 4 --hostfile hostfile -mca btl self,tcp --allow-run-as-root --bind-to none scripts-run/single_process.sh
With XLA:
export PYTHONPATH=/path/to/ResNet50_TensorFlow2:$PYTHONPATH
mpirun -np 4 --hostfile hostfile -mca btl self,tcp --allow-run-as-root --bind-to none scripts-run/single_process_xla.sh
## Performance and Accuracy
Test data: the [ImageNet test set](https://image-net.org/ "ImageNet official site"); accelerator: DCU-Z00-16G

| Cards | Batch size | Precision | Throughput | Accuracy | XLA enabled | Processes |
| :------: | :------: | :------: | :------: | :------: | :------: | :------: |
| 4 | 512 | fp32 | 843 examples/second | 0.7628 | No | single process |
| 4 | 512 | fp16 | - | 0.7616 | No | single process |
| 4 | 512 | fp32 | - | 0.7608 | No | four processes |
| 4 | 512 | fp16 | - | 0.7615 | No | four processes |
## References
* https://github.com/tensorflow/models/tree/master
* https://www.tensorflow.org/api_docs/python/tf/distribute/MultiWorkerMirroredStrategy
# Contents of `configfile`, inserted into resnet_ctl_imagenet_main.py by the sed
# command above. It builds the TF_CONFIG environment variable for multi-worker
# training: node names are read from ./nodefile and each node gets 4 consecutive
# ports starting at 40000 (e.g. a nodefile containing only `localhost` yields
# workers localhost:40000 ... localhost:40003).
worker = []
port_number = 40000
filePath = './nodefile'
with open(filePath, 'r') as f:
    nodename = f.read().splitlines()
worker_nodes = nodename
num_index = flags_obj.task_index
for node in worker_nodes:
    for index in range(4):
        worker_sockets = ":".join([node, str(port_number + index)])
        worker.append(worker_sockets)
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': worker
    },
    'task': {'type': 'worker', 'index': num_index}
})
localhost slots=4
localhost
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility library for picking an appropriate dataset function."""
from typing import Any, Callable, Union, Type
import tensorflow as tf
PossibleDatasetType = Union[Type[tf.data.Dataset], Callable[[tf.Tensor], Any]]
def pick_dataset_fn(file_type: str) -> PossibleDatasetType:
if file_type == 'tfrecord':
return tf.data.TFRecordDataset
raise ValueError('Unrecognized file_type: {}'.format(file_type))
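A minimal usage sketch of the helper above; the import path and the shard filename are assumptions for illustration:
```python
# Module path assumed; adjust to wherever pick_dataset_fn lives in this repo.
from official.common.dataset_fn import pick_dataset_fn

dataset_cls = pick_dataset_fn("tfrecord")             # -> tf.data.TFRecordDataset
ds = dataset_cls(["/path/to/train-00000-of-01024"])   # hypothetical shard path
for raw_record in ds.take(1):
    print(raw_record.dtype)                           # tf.string (serialized Example)
```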
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions for running models in a distributed setting."""
import json
import os
import tensorflow as tf
def _collective_communication(all_reduce_alg):
"""Return a CollectiveCommunication based on all_reduce_alg.
Args:
all_reduce_alg: a string specifying which collective communication to pick,
or None.
Returns:
tf.distribute.experimental.CollectiveCommunication object
Raises:
ValueError: if `all_reduce_alg` not in [None, "ring", "nccl"]
"""
collective_communication_options = {
None: tf.distribute.experimental.CollectiveCommunication.AUTO,
"ring": tf.distribute.experimental.CollectiveCommunication.RING,
"nccl": tf.distribute.experimental.CollectiveCommunication.NCCL
}
if all_reduce_alg not in collective_communication_options:
raise ValueError(
"When used with `multi_worker_mirrored`, valid values for "
"all_reduce_alg are [`ring`, `nccl`]. Supplied value: {}".format(
all_reduce_alg))
return collective_communication_options[all_reduce_alg]
def _mirrored_cross_device_ops(all_reduce_alg, num_packs):
"""Return a CrossDeviceOps based on all_reduce_alg and num_packs.
Args:
all_reduce_alg: a string specifying which cross device op to pick, or None.
num_packs: an integer specifying number of packs for the cross device op.
Returns:
tf.distribute.CrossDeviceOps object or None.
Raises:
ValueError: if `all_reduce_alg` not in [None, "nccl", "hierarchical_copy"].
"""
if all_reduce_alg is None:
return None
mirrored_all_reduce_options = {
"nccl": tf.distribute.NcclAllReduce,
"hierarchical_copy": tf.distribute.HierarchicalCopyAllReduce
}
if all_reduce_alg not in mirrored_all_reduce_options:
raise ValueError(
"When used with `mirrored`, valid values for all_reduce_alg are "
"[`nccl`, `hierarchical_copy`]. Supplied value: {}".format(
all_reduce_alg))
cross_device_ops_class = mirrored_all_reduce_options[all_reduce_alg]
return cross_device_ops_class(num_packs=num_packs)
def tpu_initialize(tpu_address):
"""Initializes TPU for TF 2.x training.
Args:
tpu_address: string, bns address of master TPU worker.
Returns:
A TPUClusterResolver.
"""
cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
tpu=tpu_address)
if tpu_address not in ("", "local"):
tf.config.experimental_connect_to_cluster(cluster_resolver)
tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
return cluster_resolver
def get_distribution_strategy(distribution_strategy="mirrored",
num_gpus=0,
all_reduce_alg=None,
num_packs=1,
tpu_address=None,
**kwargs):
"""Return a DistributionStrategy for running the model.
Args:
distribution_strategy: a string specifying which distribution strategy to
use. Accepted values are "off", "one_device", "mirrored",
"parameter_server", "multi_worker_mirrored", and "tpu" -- case
insensitive. "tpu" means to use TPUStrategy using `tpu_address`.
"off" means to use the default strategy which is obtained from
tf.distribute.get_strategy (for details on the default strategy, see
https://www.tensorflow.org/guide/distributed_training#default_strategy).
num_gpus: Number of GPUs to run this model.
all_reduce_alg: Optional. Specifies which algorithm to use when performing
all-reduce. For `MirroredStrategy`, valid values are "nccl" and
"hierarchical_copy". For `MultiWorkerMirroredStrategy`, valid values are
"ring" and "nccl". If None, DistributionStrategy will choose based on
device topology.
num_packs: Optional. Sets the `num_packs` in `tf.distribute.NcclAllReduce`
or `tf.distribute.HierarchicalCopyAllReduce` for `MirroredStrategy`.
tpu_address: Optional. String that represents TPU to connect to. Must not be
None if `distribution_strategy` is set to `tpu`.
**kwargs: Additional kwargs for internal usages.
Returns:
tf.distribute.DistributionStrategy object.
Raises:
ValueError: if `distribution_strategy` is "off" or "one_device" and
`num_gpus` is larger than 1; or `num_gpus` is negative or if
`distribution_strategy` is `tpu` but `tpu_address` is not specified.
"""
del kwargs
if num_gpus < 0:
raise ValueError("`num_gpus` can not be negative.")
if not isinstance(distribution_strategy, str):
msg = ("distribution_strategy must be a string but got: %s." %
(distribution_strategy,))
if distribution_strategy == False: # pylint: disable=singleton-comparison,g-explicit-bool-comparison
msg += (" If you meant to pass the string 'off', make sure you add "
"quotes around 'off' so that yaml interprets it as a string "
"instead of a bool.")
raise ValueError(msg)
distribution_strategy = distribution_strategy.lower()
if distribution_strategy == "off":
if num_gpus > 1:
raise ValueError(f"When {num_gpus} GPUs are specified, "
"distribution_strategy flag cannot be set to `off`.")
# Return the default distribution strategy.
return tf.distribute.get_strategy()
if distribution_strategy == "tpu":
# When tpu_address is an empty string, we communicate with local TPUs.
cluster_resolver = tpu_initialize(tpu_address)
return tf.distribute.TPUStrategy(cluster_resolver)
if distribution_strategy == "multi_worker_mirrored":
return tf.distribute.experimental.MultiWorkerMirroredStrategy(
communication=_collective_communication(all_reduce_alg))
if distribution_strategy == "one_device":
if num_gpus == 0:
return tf.distribute.OneDeviceStrategy("device:CPU:0")
if num_gpus > 1:
raise ValueError("`OneDeviceStrategy` can not be used for more than "
"one device.")
return tf.distribute.OneDeviceStrategy("device:GPU:0")
if distribution_strategy == "mirrored":
if num_gpus == 0:
devices = ["device:CPU:0"]
else:
devices = ["device:GPU:%d" % i for i in range(num_gpus)]
return tf.distribute.MirroredStrategy(
devices=devices,
cross_device_ops=_mirrored_cross_device_ops(all_reduce_alg, num_packs))
if distribution_strategy == "parameter_server":
cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
return tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)
raise ValueError("Unrecognized Distribution Strategy: %r" %
distribution_strategy)
def configure_cluster(worker_hosts=None, task_index=-1):
"""Set multi-worker cluster spec in TF_CONFIG environment variable.
Args:
worker_hosts: comma-separated list of worker ip:port pairs.
task_index: index of the worker.
Returns:
Number of workers in the cluster.
"""
tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
if tf_config:
num_workers = (
len(tf_config["cluster"].get("chief", [])) +
len(tf_config["cluster"].get("worker", [])))
elif worker_hosts:
workers = worker_hosts.split(",")
num_workers = len(workers)
if num_workers > 1 and task_index < 0:
raise ValueError("Must specify task_index when number of workers > 1")
task_index = 0 if num_workers == 1 else task_index
os.environ["TF_CONFIG"] = json.dumps({
"cluster": {
"worker": workers
},
"task": {
"type": "worker",
"index": task_index
}
})
else:
num_workers = 1
return num_workers
def get_strategy_scope(strategy):
if strategy:
strategy_scope = strategy.scope()
else:
strategy_scope = DummyContextManager()
return strategy_scope
class DummyContextManager(object):
def __enter__(self):
pass
def __exit__(self, *args):
pass
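A minimal usage sketch of the helpers above. num_gpus=0 keeps the sketch runnable on CPU; on the DCU image you would pass the actual card count, as the training script does:
```python
import tensorflow as tf

from official.common import distribute_utils

# CPU-only mirrored strategy for illustration; use num_gpus=4 on a four-card node.
strategy = distribute_utils.get_distribution_strategy(
    distribution_strategy="mirrored", num_gpus=0)

with distribute_utils.get_strategy_scope(strategy):
    # Variables created here are placed/replicated according to the strategy.
    model = tf.keras.Sequential([tf.keras.layers.Dense(10)])

print("Replicas in sync:", strategy.num_replicas_in_sync)
```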
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for distribution util functions."""
import tensorflow as tf
from official.common import distribute_utils
class DistributeUtilsTest(tf.test.TestCase):
"""Tests for distribute util functions."""
def test_invalid_args(self):
with self.assertRaisesRegex(ValueError, '`num_gpus` can not be negative.'):
_ = distribute_utils.get_distribution_strategy(num_gpus=-1)
with self.assertRaisesRegex(ValueError,
'.*If you meant to pass the string .*'):
_ = distribute_utils.get_distribution_strategy(
distribution_strategy=False, num_gpus=0)
with self.assertRaisesRegex(ValueError, 'When 2 GPUs are specified.*'):
_ = distribute_utils.get_distribution_strategy(
distribution_strategy='off', num_gpus=2)
with self.assertRaisesRegex(ValueError,
'`OneDeviceStrategy` can not be used.*'):
_ = distribute_utils.get_distribution_strategy(
distribution_strategy='one_device', num_gpus=2)
def test_one_device_strategy_cpu(self):
ds = distribute_utils.get_distribution_strategy('one_device', num_gpus=0)
self.assertEquals(ds.num_replicas_in_sync, 1)
self.assertEquals(len(ds.extended.worker_devices), 1)
self.assertIn('CPU', ds.extended.worker_devices[0])
def test_one_device_strategy_gpu(self):
ds = distribute_utils.get_distribution_strategy('one_device', num_gpus=1)
self.assertEquals(ds.num_replicas_in_sync, 1)
self.assertEquals(len(ds.extended.worker_devices), 1)
self.assertIn('GPU', ds.extended.worker_devices[0])
def test_mirrored_strategy(self):
ds = distribute_utils.get_distribution_strategy(num_gpus=5)
self.assertEquals(ds.num_replicas_in_sync, 5)
self.assertEquals(len(ds.extended.worker_devices), 5)
for device in ds.extended.worker_devices:
self.assertIn('GPU', device)
_ = distribute_utils.get_distribution_strategy(
distribution_strategy='mirrored',
num_gpus=2,
all_reduce_alg='nccl',
num_packs=2)
with self.assertRaisesRegex(
ValueError,
'When used with `mirrored`, valid values for all_reduce_alg are.*'):
_ = distribute_utils.get_distribution_strategy(
distribution_strategy='mirrored',
num_gpus=2,
all_reduce_alg='dummy',
num_packs=2)
def test_mwms(self):
distribute_utils.configure_cluster(worker_hosts=None, task_index=-1)
ds = distribute_utils.get_distribution_strategy(
'multi_worker_mirrored', all_reduce_alg='nccl')
self.assertIsInstance(
ds, tf.distribute.experimental.MultiWorkerMirroredStrategy)
def test_no_strategy(self):
ds = distribute_utils.get_distribution_strategy('off')
self.assertIs(ds, tf.distribute.get_strategy())
def test_invalid_strategy(self):
with self.assertRaisesRegexp(
ValueError,
'distribution_strategy must be a string but got: False. If'):
distribute_utils.get_distribution_strategy(False)
with self.assertRaisesRegexp(
ValueError, 'distribution_strategy must be a string but got: 1'):
distribute_utils.get_distribution_strategy(1)
def test_get_strategy_scope(self):
ds = distribute_utils.get_distribution_strategy('one_device', num_gpus=0)
with distribute_utils.get_strategy_scope(ds):
self.assertIs(tf.distribute.get_strategy(), ds)
with distribute_utils.get_strategy_scope(None):
self.assertIsNot(tf.distribute.get_strategy(), ds)
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The central place to define flags."""
from absl import flags
def define_flags():
"""Defines flags.
All flags are defined as optional, but in practice most models use some of
these flags and so mark_flags_as_required() should be called after calling
this function. Typically, 'experiment', 'mode', and 'model_dir' are required.
For example:
```
from absl import flags
from official.common import flags as tfm_flags # pylint: disable=line-too-long
...
tfm_flags.define_flags()
flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
```
The reason all flags are optional is because unit tests often do not set or
use any of the flags.
"""
flags.DEFINE_string(
'experiment', default=None, help=
'The experiment type registered, specifying an ExperimentConfig.')
flags.DEFINE_enum(
'mode',
default=None,
enum_values=[
'train', 'eval', 'train_and_eval', 'continuous_eval',
'continuous_train_and_eval', 'train_and_validate'
],
help='Mode to run: `train`, `eval`, `train_and_eval`, '
'`continuous_eval`, `continuous_train_and_eval` and '
'`train_and_validate` (which is not implemented in '
'the open source version).')
flags.DEFINE_string(
'model_dir',
default=None,
help='The directory where the model and training/evaluation summaries'
'are stored.')
flags.DEFINE_multi_string(
'config_file',
default=None,
help='YAML/JSON files which specifies overrides. The override order '
'follows the order of args. Note that each file '
'can be used as an override template to override the default parameters '
'specified in Python. If the same parameter is specified in both '
'`--config_file` and `--params_override`, `config_file` will be used '
'first, followed by params_override.')
flags.DEFINE_string(
'params_override',
default=None,
help='a YAML/JSON string or a YAML file which specifies additional '
'overrides over the default parameters and those specified in '
'`--config_file`. Note that this is supposed to be used only to override '
'the model parameters, but not the parameters like TPU specific flags. '
'One canonical use case of `--config_file` and `--params_override` is '
'users first define a template config file using `--config_file`, then '
'use `--params_override` to adjust the minimal set of tuning parameters, '
'for example setting up different `train_batch_size`. The final override '
'order of parameters: default_model_params --> params from config_file '
'--> params in params_override. See also the help message of '
'`--config_file`.')
# Libraries that rely on gin often mistakenly define flags inside the
# library files, which causes conflicts.
try:
flags.DEFINE_multi_string(
'gin_file', default=None, help='List of paths to the config files.')
except flags.DuplicateFlagError:
pass
try:
flags.DEFINE_multi_string(
'gin_params',
default=None,
help='Newline separated list of Gin parameter bindings.')
except flags.DuplicateFlagError:
pass
flags.DEFINE_string(
'tpu',
default=None,
help='The Cloud TPU to use for training. This should be either the name '
'used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 '
'url.')
flags.DEFINE_string(
'tf_data_service', default=None, help='The tf.data service address')
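A minimal sketch of wiring these flags into a script, following the usage shown in the docstring above:
```python
from absl import app
from absl import flags

from official.common import flags as tfm_flags

FLAGS = flags.FLAGS


def main(_):
    print("experiment:", FLAGS.experiment)
    print("mode:", FLAGS.mode)
    print("model_dir:", FLAGS.model_dir)


if __name__ == "__main__":
    tfm_flags.define_flags()
    flags.mark_flags_as_required(["experiment", "mode", "model_dir"])
    app.run(main)
```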
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""All necessary imports for registration."""
# pylint: disable=unused-import
from official.nlp import tasks
from official.nlp.configs import experiment_configs
from official.utils.testing import mock_task
from official.vision import beta
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Global streamz counters."""
from tensorflow.python.eager import monitoring
progressive_policy_creation_counter = monitoring.Counter(
"/tensorflow/training/fast_training/progressive_policy_creation",
"Counter for the number of ProgressivePolicy creations.")
stack_vars_to_vars_call_counter = monitoring.Counter(
"/tensorflow/training/fast_training/tf_vars_to_vars",
"Counter for the number of low-level stacking API calls.")
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides TFM orbit actions and associated helper functions/classes."""
import os
from typing import List
from absl import logging
import gin
import orbit
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from official.core import base_trainer
from official.core import config_definitions
from official.modeling import optimization
class PruningActions:
"""Train action to updates pruning related information.
This action updates pruning steps at the end of the training loop, and logs
pruning metrics to TensorBoard.
This action must be used when training a pruned model to avoid pruning error.
"""
def __init__(
self,
export_dir: str,
model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer,
):
"""Initializes the instance.
Args:
export_dir: `str` for the export directory of the pruning summaries.
model: `tf.keras.Model` model instance used for training. This will be
used to assign a pruning step to each prunable weight.
optimizer: `tf.keras.optimizers.Optimizer` optimizer instance used for
training. This will be used to find the current training steps.
"""
self._optimizer = optimizer
self.update_pruning_step = tfmot.sparsity.keras.UpdatePruningStep()
self.update_pruning_step.set_model(model)
self.update_pruning_step.on_train_begin()
self.pruning_summaries = tfmot.sparsity.keras.PruningSummaries(
log_dir=export_dir)
model.optimizer = optimizer
self.pruning_summaries.set_model(model)
def __call__(self, output: orbit.runner.Output):
"""Update pruning step and log pruning summaries.
Args:
output: The train output to test.
"""
self.update_pruning_step.on_epoch_end(batch=None)
self.pruning_summaries.on_epoch_begin(epoch=None)
class EMACheckpointing:
"""Eval action to save checkpoint with average weights when EMA is used.
This action swaps the weights of the model with the average weights, then it
saves the checkpoint under export_dir/ema_checkpoints. Checkpointing is
expensive for large models, so doing this action in eval is more efficient
than training.
"""
def __init__(self, export_dir: str, optimizer: tf.keras.optimizers.Optimizer,
checkpoint: tf.train.Checkpoint, max_to_keep: int = 1):
"""Initializes the instance.
Args:
export_dir: `str` for the export directory of the EMA average weights.
optimizer: `tf.keras.optimizers.Optimizer` optimizer instance used for
training. This will be used to swap the model weights with the average
weights.
checkpoint: `tf.train.Checkpoint` instance.
max_to_keep: `int` for max checkpoints to keep in ema_checkpoints subdir.
"""
if not isinstance(optimizer, optimization.ExponentialMovingAverage):
raise ValueError('Optimizer has to be instance of'
'optimization.ExponentialMovingAverage for'
'EMACheckpointing action')
export_dir = os.path.join(export_dir, 'ema_checkpoints')
tf.io.gfile.makedirs(
os.path.dirname(export_dir))
self._optimizer = optimizer
self._checkpoint = checkpoint
self._checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
directory=export_dir,
max_to_keep=max_to_keep,
checkpoint_name='average_weights')
def __call__(self, output: orbit.runner.Output):
"""Swaps model weights, and saves the checkpoint.
Args:
output: The train or eval output to test.
"""
self._optimizer.swap_weights()
self._checkpoint_manager.save(checkpoint_number=self._optimizer.iterations)
self._optimizer.swap_weights()
class RecoveryAction:
"""Train action to recover from loss blowup.
Checks the loss value by the given threshold. If applicable, recover the
model by reading the checkpoint on disk.
"""
def __init__(self, checkpoint_manager: tf.train.CheckpointManager):
self.checkpoint_manager = checkpoint_manager
def __call__(self, _):
"""Recovers the training by triggering checkpoint restoration."""
# Loads the previous good checkpoint.
checkpoint_path = self.checkpoint_manager.restore_or_initialize()
logging.warning('Recovering the model from checkpoint: %s.',
checkpoint_path)
class RecoveryCondition:
"""Recovery Condition."""
def __init__(self,
global_step: tf.Variable,
loss_upper_bound: float,
recovery_begin_steps: int = 0,
recovery_max_trials: int = 3):
self.recover_counter = 0
self.recovery_begin_steps = recovery_begin_steps
self.recovery_max_trials = recovery_max_trials
self.loss_upper_bound = loss_upper_bound
self.global_step = global_step
def __call__(self, outputs: orbit.runner.Output):
loss_value = outputs['training_loss']
if tf.math.is_nan(loss_value):
self.recover_counter += 1
if self.recover_counter > self.recovery_max_trials:
raise RuntimeError(
'The loss value is NaN after training loop and it happens %d times.'
% self.recover_counter)
return True
if (self.global_step >= self.recovery_begin_steps and
loss_value > self.loss_upper_bound):
self.recover_counter += 1
if self.recover_counter > self.recovery_max_trials:
raise RuntimeError(
f'The loss value is {loss_value}, which is larger than the bound {self.loss_upper_bound}, happens {self.recover_counter} times.'
)
return True
return False
@gin.configurable
def get_eval_actions(
params: config_definitions.ExperimentConfig,
trainer: base_trainer.Trainer,
model_dir: str) -> List[orbit.Action]:
"""Gets eval actions for TFM trainer."""
eval_actions = []
# Adds ema checkpointing action to save the average weights under
# ema_checkpoints subdir.
if isinstance(trainer.optimizer, optimization.ExponentialMovingAverage):
eval_actions.append(
EMACheckpointing(
export_dir=model_dir,
optimizer=trainer.optimizer,
checkpoint=trainer.checkpoint,
max_to_keep=params.trainer.max_to_keep))
return eval_actions
@gin.configurable
def get_train_actions(
params: config_definitions.ExperimentConfig, trainer: base_trainer.Trainer,
model_dir: str,
checkpoint_manager: tf.train.CheckpointManager) -> List[orbit.Action]:
"""Gets train actions for TFM trainer."""
train_actions = []
# Adds pruning callback actions.
if hasattr(params.task, 'pruning'):
train_actions.append(
PruningActions(
export_dir=model_dir,
model=trainer.model,
optimizer=trainer.optimizer))
if params.trainer.recovery_max_trials >= 0:
recovery_condition = RecoveryCondition(
global_step=trainer.global_step,
loss_upper_bound=params.trainer.loss_upper_bound,
recovery_begin_steps=params.trainer.recovery_begin_steps,
recovery_max_trials=params.trainer.recovery_max_trials,
)
recover_action = orbit.actions.ConditionalAction(
condition=recovery_condition,
action=RecoveryAction(checkpoint_manager),
)
train_actions.append(recover_action)
return train_actions
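A minimal sketch of composing the recovery pieces above outside of get_train_actions; the checkpoint directory and thresholds are illustrative:
```python
import orbit
import tensorflow as tf

from official.core import actions

# Hypothetical checkpoint setup; in the trainer this manager tracks real weights.
global_step = orbit.utils.create_global_step()
checkpoint = tf.train.Checkpoint(global_step=global_step)
manager = tf.train.CheckpointManager(
    checkpoint, directory="/tmp/recovery_demo", max_to_keep=1)

condition = actions.RecoveryCondition(
    global_step=global_step, loss_upper_bound=100.0, recovery_max_trials=3)
recover_action = orbit.actions.ConditionalAction(
    condition=condition, action=actions.RecoveryAction(manager))

# Invoked with the training output dict after each training loop; a loss within
# bounds does nothing, an out-of-range or NaN loss triggers a checkpoint restore.
recover_action({"training_loss": tf.constant(1.5)})
```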
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for TFM actions."""
import os
from absl.testing import parameterized
import numpy as np
import orbit
import tensorflow as tf
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.core import actions
from official.modeling import optimization
class TestModel(tf.Module):
def __init__(self):
self.value = tf.Variable(0)
@tf.function(input_signature=[])
def __call__(self):
return self.value
class ActionsTest(tf.test.TestCase, parameterized.TestCase):
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],))
def test_ema_checkpointing(self, distribution):
with distribution.scope():
directory = self.create_tempdir()
model = TestModel()
optimizer = tf.keras.optimizers.SGD()
optimizer = optimization.ExponentialMovingAverage(
optimizer, trainable_weights_only=False)
# Creats average weights for the model variables. Average weights are
# initialized to zero.
optimizer.shadow_copy(model)
checkpoint = tf.train.Checkpoint(model=model)
# Changes model.value to 3, average value is still 0.
model.value.assign(3)
# Checks model.value is 3
self.assertEqual(model(), 3)
ema_action = actions.EMACheckpointing(directory, optimizer, checkpoint)
ema_action({})
self.assertNotEmpty(
tf.io.gfile.glob(os.path.join(directory, 'ema_checkpoints')))
checkpoint.read(tf.train.latest_checkpoint(
os.path.join(directory, 'ema_checkpoints')))
# Checks model.value is 0 after swapping.
self.assertEqual(model(), 0)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
],))
def test_recovery_condition(self, distribution):
with distribution.scope():
global_step = orbit.utils.create_global_step()
recover_condition = actions.RecoveryCondition(
global_step, loss_upper_bound=0.5, recovery_max_trials=2)
outputs = {'training_loss': 0.6}
self.assertTrue(recover_condition(outputs))
self.assertTrue(recover_condition(outputs))
with self.assertRaises(RuntimeError):
recover_condition(outputs)
global_step = orbit.utils.create_global_step()
recover_condition = actions.RecoveryCondition(
global_step, loss_upper_bound=0.5, recovery_max_trials=2)
outputs = {'training_loss': tf.constant([np.nan], tf.float32)}
self.assertTrue(recover_condition(outputs))
self.assertTrue(recover_condition(outputs))
with self.assertRaises(RuntimeError):
recover_condition(outputs)
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines the base task abstraction."""
import abc
from typing import Optional
from absl import logging
import tensorflow as tf
from official.core import config_definitions
from official.modeling import optimization
from official.modeling import performance
OptimizationConfig = optimization.OptimizationConfig
RuntimeConfig = config_definitions.RuntimeConfig
class Task(tf.Module, metaclass=abc.ABCMeta):
"""A single-replica view of training procedure.
Tasks provide artifacts for training/validation procedures, including
loading/iterating over Datasets, training/validation steps, calculating the
loss and customized metrics with reduction.
"""
# Special keys in train/validate step returned logs.
loss = "loss"
def __init__(self,
params,
logging_dir: Optional[str] = None,
name: Optional[str] = None):
"""Task initialization.
Args:
params: the task configuration instance, which can be any of dataclass,
ConfigDict, namedtuple, etc.
logging_dir: a string pointing to where the model, summaries etc. will be
saved. You can also write additional stuff in this directory.
name: the task name.
"""
super().__init__(name=name)
self._task_config = params
self._logging_dir = logging_dir
@property
def task_config(self):
return self._task_config
@property
def logging_dir(self) -> str:
return self._logging_dir
@classmethod
def create_optimizer(cls, optimizer_config: OptimizationConfig,
runtime_config: Optional[RuntimeConfig] = None):
"""Creates an TF optimizer from configurations.
Args:
optimizer_config: the parameters of the Optimization settings.
runtime_config: the parameters of the runtime.
Returns:
A tf.optimizers.Optimizer object.
"""
opt_factory = optimization.OptimizerFactory(optimizer_config)
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
# Configure the optimizer when loss_scale is set in the runtime config. This
# helps avoid overflow/underflow for float16 computations.
if runtime_config:
optimizer = performance.configure_optimizer(
optimizer,
use_float16=runtime_config.mixed_precision_dtype == "float16",
loss_scale=runtime_config.loss_scale)
return optimizer
def initialize(self, model: tf.keras.Model):
"""[Optional] A callback function used as CheckpointManager's init_fn.
This function will be called when no checkpoint is found for the model.
If there is a checkpoint, the checkpoint will be loaded and this function
will not be called. You can use this callback function to load a pretrained
checkpoint, saved under a directory other than the model_dir.
Args:
model: The keras.Model built or used by this task.
"""
ckpt_dir_or_file = self.task_config.init_checkpoint
logging.info("Trying to load pretrained checkpoint from %s",
ckpt_dir_or_file)
if tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
if not ckpt_dir_or_file:
return
if hasattr(model, "checkpoint_items"):
checkpoint_items = model.checkpoint_items
else:
checkpoint_items = dict(model=model)
ckpt = tf.train.Checkpoint(**checkpoint_items)
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
logging.info("Finished loading pretrained checkpoint from %s",
ckpt_dir_or_file)
def build_model(self) -> tf.keras.Model:
"""[Optional] Creates model architecture.
Returns:
A model instance.
""" # pytype: disable=bad-return-type # typed-keras
@abc.abstractmethod
def build_inputs(self,
params,
input_context: Optional[tf.distribute.InputContext] = None):
"""Returns a dataset or a nested structure of dataset functions.
Dataset functions define per-host datasets with the per-replica batch size.
With distributed training, this method runs on remote hosts.
Args:
params: hyperparams to create input pipelines, which can be any of
dataclass, ConfigDict, namedtuple, etc.
input_context: optional distribution input pipeline context.
Returns:
A nested structure of per-replica input functions.
"""
def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
"""Standard interface to compute losses.
Args:
labels: optional label tensors.
model_outputs: a nested structure of output tensors.
aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.
Returns:
The total loss tensor.
"""
del model_outputs, labels
if aux_losses is None:
losses = [tf.constant(0.0, dtype=tf.float32)]
else:
losses = aux_losses
total_loss = tf.add_n(losses)
return total_loss
def build_metrics(self, training: bool = True):
"""Gets streaming metrics for training/validation."""
del training
return []
def process_metrics(self, metrics, labels, model_outputs, **kwargs):
"""Process and update metrics.
Called when using custom training loop API.
Args:
metrics: a nested structure of metrics objects. The return of function
self.build_metrics.
labels: a tensor or a nested structure of tensors.
model_outputs: a tensor or a nested structure of tensors. For example,
output of the keras model built by self.build_model.
**kwargs: other args.
"""
for metric in metrics:
metric.update_state(labels, model_outputs)
def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
"""Process and update compiled_metrics.
Called when using the compile/fit API.
Args:
compiled_metrics: the compiled metrics (model.compiled_metrics).
labels: a tensor or a nested structure of tensors.
model_outputs: a tensor or a nested structure of tensors. For example,
output of the keras model built by self.build_model.
"""
compiled_metrics.update_state(labels, model_outputs)
def train_step(self,
inputs,
model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer,
metrics=None):
"""Does forward and backward.
With distribution strategies, this method runs on devices.
Args:
inputs: a dictionary of input tensors.
model: the model, forward pass definition.
optimizer: the optimizer for this training step.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
if isinstance(inputs, tuple) and len(inputs) == 2:
features, labels = inputs
else:
features, labels = inputs, inputs
with tf.GradientTape() as tape:
outputs = model(features, training=True)
# Computes per-replica loss.
if model.compiled_loss:
loss = model.compiled_loss(
labels, outputs, regularization_losses=model.losses)
loss += self.build_losses(
labels=labels, model_outputs=outputs, aux_losses=None)
else:
loss = self.build_losses(
labels=labels, model_outputs=outputs, aux_losses=model.losses)
# Scales loss as the default gradients allreduce performs sum inside the
# optimizer.
scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
# For mixed precision, when a LossScaleOptimizer is used, the loss is
# scaled to avoid numeric underflow.
if isinstance(optimizer,
tf.keras.mixed_precision.LossScaleOptimizer):
scaled_loss = optimizer.get_scaled_loss(scaled_loss)
tvars = model.trainable_variables
grads = tape.gradient(scaled_loss, tvars)
if isinstance(optimizer,
tf.keras.mixed_precision.LossScaleOptimizer):
grads = optimizer.get_unscaled_gradients(grads)
optimizer.apply_gradients(list(zip(grads, tvars)))
logs = {self.loss: loss}
if metrics:
self.process_metrics(metrics, labels, outputs)
if model.compiled_metrics:
self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
logs.update({m.name: m.result() for m in metrics or []})
logs.update({m.name: m.result() for m in model.metrics})
return logs
def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
"""Validation step.
With distribution strategies, this method runs on devices.
Args:
inputs: a dictionary of input tensors.
model: the keras.Model.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
if isinstance(inputs, tuple) and len(inputs) == 2:
features, labels = inputs
else:
features, labels = inputs, inputs
outputs = self.inference_step(features, model)
loss = self.build_losses(
labels=labels, model_outputs=outputs, aux_losses=model.losses)
logs = {self.loss: loss}
if metrics:
self.process_metrics(metrics, labels, outputs)
if model.compiled_metrics:
self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
logs.update({m.name: m.result() for m in metrics or []})
logs.update({m.name: m.result() for m in model.metrics})
return logs
def inference_step(self, inputs, model: tf.keras.Model):
"""Performs the forward step.
With distribution strategies, this method runs on devices.
Args:
inputs: a dictionary of input tensors.
model: the keras.Model.
Returns:
Model outputs.
"""
return model(inputs, training=False)
def aggregate_logs(self, state, step_logs):
"""Optional aggregation over logs returned from a validation step.
Given step_logs from a validation step, this function aggregates the logs
after each eval_step() (see eval_reduce() function in
official/core/base_trainer.py). It runs on CPU and can be used to aggregate
metrics during validation, when there are too many metrics that cannot fit
into TPU memory. Note that this may increase latency due to data transfer
between TPU and CPU. Also, the step output from a validation step may be a
tuple with elements from replicas, and a concatenation of the elements is
needed in such case.
Args:
state: The current state of training, for example, it can be a sequence of
metrics.
step_logs: Logs from a validation step. Can be a dictionary.
"""
pass
def reduce_aggregated_logs(self,
aggregated_logs,
global_step: Optional[tf.Tensor] = None):
"""Optional reduce of aggregated logs over validation steps.
This function reduces aggregated logs at the end of validation, and can be
used to compute the final metrics. It runs on CPU and in each eval_end() in
base trainer (see eval_end() function in official/core/base_trainer.py).
Args:
aggregated_logs: Aggregated logs over multiple validation steps.
global_step: An optional variable of global step.
Returns:
A dictionary of reduced results.
"""
return {}
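A minimal, hypothetical Task subclass (names and data are illustrative) showing how the interface above fits together with the default train_step:
```python
import tensorflow as tf

from official.core import base_task


class ToyTask(base_task.Task):
    """A tiny, hypothetical task: fit y = sum(x) with a single Dense layer."""

    def build_model(self) -> tf.keras.Model:
        return tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])

    def build_inputs(self, params, input_context=None):
        x = tf.random.normal([64, 4])
        y = tf.reduce_sum(x, axis=1, keepdims=True)
        return tf.data.Dataset.from_tensor_slices((x, y)).batch(8)

    def build_losses(self, labels, model_outputs, aux_losses=None):
        loss = tf.reduce_mean(tf.square(labels - model_outputs))
        if aux_losses:
            loss += tf.add_n(aux_losses)
        return loss


task = ToyTask(params=None)
model = task.build_model()
optimizer = tf.keras.optimizers.SGD(0.1)
for batch in task.build_inputs(params=None):
    logs = task.train_step(batch, model=model, optimizer=optimizer)
print("final training loss:", float(logs[task.loss]))
```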
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Standard Trainer implementation.
The base trainer implements the Orbit `StandardTrainable` and
`StandardEvaluable` interfaces. Trainers inside this project should be
interchangeable and independent of model architectures and tasks.
"""
import functools
from typing import Union, Optional
from absl import logging
import gin
import orbit
import tensorflow as tf
from official.core import base_task
from official.core import config_definitions
from official.modeling import optimization
ExperimentConfig = config_definitions.ExperimentConfig
TrainerConfig = config_definitions.TrainerConfig
class Recovery:
"""Built-in model blowup recovery module.
Checks the loss value by the given threshold. If applicable, recover the
model by reading the checkpoint on disk.
"""
def __init__(self,
loss_upper_bound: float,
checkpoint_manager: tf.train.CheckpointManager,
recovery_begin_steps: int = 0,
recovery_max_trials: int = 3):
self.recover_counter = 0
self.recovery_begin_steps = recovery_begin_steps
self.recovery_max_trials = recovery_max_trials
self.loss_upper_bound = loss_upper_bound
self.checkpoint_manager = checkpoint_manager
def should_recover(self, loss_value, global_step):
if tf.math.is_nan(loss_value):
return True
if (global_step >= self.recovery_begin_steps and
loss_value > self.loss_upper_bound):
return True
return False
def maybe_recover(self, loss_value, global_step):
"""Conditionally recovers the training by triggering checkpoint restoration.
Args:
loss_value: the loss value as a float.
global_step: the number of global training steps.
Raises:
RuntimeError: when recovery happens more than the max number of trials,
the job should crash.
"""
if not self.should_recover(loss_value, global_step):
return
self.recover_counter += 1
if self.recover_counter > self.recovery_max_trials:
raise RuntimeError(
"The loss value is NaN or out of range after training loop and "
f"this happens {self.recover_counter} times.")
# Loads the previous good checkpoint.
checkpoint_path = self.checkpoint_manager.restore_or_initialize()
logging.warning(
"Recovering the model from checkpoint: %s. The loss value becomes "
"%f at step %d.", checkpoint_path, loss_value, global_step)
class _AsyncTrainer(orbit.StandardTrainer, orbit.StandardEvaluator):
"""Trainer class for both sync and async Strategy."""
def init_async(self):
"""Initializes the Async Trainer base class."""
assert isinstance(self._strategy, tf.distribute.Strategy)
self._is_async = isinstance(
self._strategy, tf.distribute.experimental.ParameterServerStrategy)
self._coordinator = None
if self._is_async:
self._coordinator = (
tf.distribute.experimental.coordinator.ClusterCoordinator(
self._strategy))
def join(self):
"""Join all async steps. Only useful in aysnc training."""
if getattr(self, "_is_async", False):
self._coordinator.join()
def create_train_loop_fn(self):
"""Creates a eval loop from the given step function and options."""
train_loop_fn = super().create_train_loop_fn()
if getattr(self, "_is_async", False):
def _async_loop_fn(iterator, num_steps):
self._coordinator.schedule(train_loop_fn, args=(iterator, num_steps))
return _async_loop_fn
else:
return train_loop_fn
def create_eval_loop_fn(self, has_state: bool):
"""Creates a training loop from the given step function and options."""
eval_loop_fn = super().create_eval_loop_fn(has_state)
if getattr(self, "_is_async", False):
if has_state:
raise ValueError(
"Stateful eval loop is not supported in async training.")
def _async_loop_fn(iterator, num_steps, state=None, reduce_fn=None):
assert state is None
assert reduce_fn is None
self._coordinator.schedule(eval_loop_fn, args=(iterator, num_steps))
return _async_loop_fn
else:
return eval_loop_fn
def distribute_dataset(self, dataset_or_fn, *args, **kwargs):
"""A utility function to help create a `tf.distribute.DistributedDataset`.
Args:
dataset_or_fn: An instance of `tf.data.Dataset`, or a "dataset function"
returning a `tf.data.Dataset`. If it is a function, it may optionally
have an argument named `input_context` which will be passed a
`tf.distribute.InputContext` instance.
*args: Any positional arguments to pass through to `dataset_or_fn`.
**kwargs: Any keyword arguments to pass through to `dataset_or_fn`.
Returns:
A distributed Dataset.
"""
if getattr(self, "_is_async", False):
per_worker_dataset_fn = functools.partial(
orbit.utils.make_distributed_dataset, self._strategy, dataset_or_fn,
*args, **kwargs)
per_worker_dataset_fn = tf.function(per_worker_dataset_fn)
return self._coordinator.create_per_worker_dataset(per_worker_dataset_fn)
else:
return orbit.utils.make_distributed_dataset(self._strategy, dataset_or_fn,
*args, **kwargs)
def get_runtime_options(config: ExperimentConfig):
"""Get tf.distribute.RunOptions from config."""
xla_options = {}
if config.runtime.tpu_enable_xla_dynamic_padder is not None:
xla_options["enable_xla_dynamic_padder"] = (
config.runtime.tpu_enable_xla_dynamic_padder)
return tf.distribute.RunOptions(
experimental_xla_options=tf.tpu.XLAOptions(**xla_options))
@gin.configurable
class Trainer(_AsyncTrainer):
"""Implements the common trainer shared for TensorFlow models."""
# pylint: disable=super-init-not-called
def __init__(
self,
config: ExperimentConfig,
task: base_task.Task,
model: tf.keras.Model,
optimizer: tf.optimizers.Optimizer,
train: bool = True,
evaluate: bool = True,
train_dataset: Optional[Union[tf.data.Dataset,
tf.distribute.DistributedDataset]] = None,
validation_dataset: Optional[Union[
tf.data.Dataset, tf.distribute.DistributedDataset]] = None,
checkpoint_exporter=None):
"""Initialize common trainer for TensorFlow models.
Args:
config: An `ExperimentConfig` instance specifying experiment config.
task: A base_task.Task instance.
model: The model instance, e.g. a tf.keras.Model instance.
optimizer: tf.optimizers.Optimizer instance.
train: bool, whether or not this trainer will be used for training.
default to True.
evaluate: bool, whether or not this trainer will be used for evaluation.
default to True.
train_dataset: a dataset object created for training. With tf.distribute,
it needs to be a `DistributedDataset`.
validation_dataset: a dataset object created for evaluation. With
tf.distribute, it needs to be a `DistributedDataset`. The evaluator will
create a dataset iterator for each eval round, so the dataset does not
need to repeat.
checkpoint_exporter: an object that has the `maybe_export_checkpoint`
interface.
"""
# Gets the current distribution strategy. If not inside any strategy scope,
# it gets a single-replica no-op strategy.
self._strategy = tf.distribute.get_strategy()
self._validate_params(
config,
check_train_data=train_dataset is None,
check_validation_data=validation_dataset is None)
self._config = config
self._task = task
self._model = model
self._optimizer = optimizer
self._checkpoint_exporter = checkpoint_exporter
self._recovery = None
# Runtime options are only applied to train_step.
# We use default for eval_step.
self._runtime_options = get_runtime_options(config)
# Creates a shadow copy of the weights to store weights moving average.
if isinstance(self._optimizer, optimization.ExponentialMovingAverage
) and not self._optimizer.has_shadow_copy:
self._optimizer.shadow_copy(self._model)
# global_step increases by 1 after each training iteration.
# We should have global_step.numpy() == self.optimizer.iterations.numpy()
# when there is only 1 optimizer.
self._global_step = orbit.utils.create_global_step()
if hasattr(self.model, "checkpoint_items"):
checkpoint_items = self.model.checkpoint_items
else:
checkpoint_items = {}
self._checkpoint = tf.train.Checkpoint(
global_step=self.global_step,
model=self.model,
optimizer=self.optimizer,
**checkpoint_items)
self._train_loss = tf.keras.metrics.Mean("training_loss", dtype=tf.float32)
self._validation_loss = tf.keras.metrics.Mean(
"validation_loss", dtype=tf.float32)
model_metrics = model.metrics if hasattr(model, "metrics") else []
self._train_metrics = self.task.build_metrics(
training=True) + model_metrics
self._validation_metrics = self.task.build_metrics(
training=False) + model_metrics
self.init_async()
if train:
train_dataset = train_dataset or self.distribute_dataset(
self.task.build_inputs, self.config.task.train_data)
orbit.StandardTrainer.__init__(
self,
train_dataset,
options=orbit.StandardTrainerOptions(
use_tf_while_loop=config.trainer.train_tf_while_loop,
use_tf_function=config.trainer.train_tf_function,
use_tpu_summary_optimization=config.trainer.allow_tpu_summary))
if evaluate:
validation_dataset = validation_dataset or self.distribute_dataset(
self.task.build_inputs, self.config.task.validation_data)
orbit.StandardEvaluator.__init__(
self,
validation_dataset,
options=orbit.StandardEvaluatorOptions(
use_tf_function=config.trainer.eval_tf_function,
use_tf_while_loop=config.trainer.eval_tf_while_loop))
def _validate_params(self,
config,
check_train_data=True,
check_validation_data=True):
r"""Validates if the configuration object passed to the Trainer.
The experiment configuration should be structured as:
\trainer
\task
\train_data
\validation_data
Args:
config: a namedtuple, dataclass, ConfigDict, etc.
check_train_data: whether to check task.train_data field.
check_validation_data: whether to check task.validation_data field.
"""
if not hasattr(config, "trainer"):
raise AttributeError("The trainer requires the configuration contains an"
" attribute `trainer`.")
if not hasattr(config, "task"):
raise AttributeError("The trainer requires the configuration contains an"
" attribute `task`.")
if check_train_data and not hasattr(config.task, "train_data"):
raise AttributeError("The trainer requires the configuration contains an"
" attribute `task.train_data`.")
if check_validation_data and not hasattr(config.task, "validation_data"):
raise AttributeError("The trainer requires the configuration contains an"
" attribute `task.validation_data`.")
@property
def strategy(self):
return self._strategy
@property
def config(self):
return self._config
@property
def task(self):
return self._task
@property
def model(self):
return self._model
@property
def optimizer(self):
if hasattr(self, "_optimizer"):
return self._optimizer
else:
return None
@property
def global_step(self):
return self._global_step
@property
def train_loss(self):
"""Accesses the training loss metric object."""
return self._train_loss
@property
def validation_loss(self):
"""Accesses the validation loss metric object."""
return self._validation_loss
@property
def train_metrics(self):
"""Accesses all training metric objects."""
return self._train_metrics
@property
def validation_metrics(self):
"""Accesses all validation metric metric objects."""
return self._validation_metrics
def initialize(self):
"""A callback function.
This function will be called when no checkpoint is found for the model.
If there is a checkpoint, the checkpoint will be loaded and this function
will not be called. Tasks may use this callback function to load a
pretrained checkpoint, saved under a directory other than the model_dir.
"""
self.task.initialize(self.model)
@property
def checkpoint(self):
"""Accesses the training checkpoint."""
return self._checkpoint
# TODO(yejiayu): Remove this once all deps are fixed.
def add_recovery(self, params: TrainerConfig,
checkpoint_manager: tf.train.CheckpointManager):
if params.recovery_max_trials >= 0:
self._recovery = Recovery(
loss_upper_bound=params.loss_upper_bound,
recovery_begin_steps=params.recovery_begin_steps,
recovery_max_trials=params.recovery_max_trials,
checkpoint_manager=checkpoint_manager)
def train_loop_end(self):
"""See base class."""
self.join()
logs = {}
for metric in self.train_metrics + [self.train_loss]:
logs[metric.name] = metric.result()
metric.reset_states()
if callable(self.optimizer.learning_rate):
# Maybe a self-implemented optimizer does not have `optimizer.iterations`.
# So just to be safe here.
if hasattr(self.optimizer, "iterations"):
logs["learning_rate"] = self.optimizer.learning_rate(
self.optimizer.iterations)
else:
logs["learning_rate"] = self.optimizer.learning_rate(self.global_step)
else:
logs["learning_rate"] = self.optimizer.learning_rate
return logs
def train_step(self, iterator):
"""See base class."""
def step_fn(inputs):
if self.config.runtime.enable_xla and (self.config.runtime.num_gpus > 0):
task_train_step = tf.function(self.task.train_step, jit_compile=True)
else:
task_train_step = self.task.train_step
logs = task_train_step(
inputs,
model=self.model,
optimizer=self.optimizer,
metrics=self.train_metrics)
self._train_loss.update_state(logs[self.task.loss])
self.global_step.assign_add(1)
self.strategy.run(
step_fn, args=(next(iterator),), options=self._runtime_options)
def eval_begin(self):
"""Sets up metrics."""
for metric in self.validation_metrics + [self.validation_loss]:
metric.reset_states()
# Swaps weights to test on weights moving average.
if self.optimizer and isinstance(self.optimizer,
optimization.ExponentialMovingAverage):
self.optimizer.swap_weights()
def eval_step(self, iterator):
"""See base class."""
def step_fn(inputs):
logs = self.task.validation_step(
inputs, model=self.model, metrics=self.validation_metrics)
if self.task.loss in logs:
self._validation_loss.update_state(logs[self.task.loss])
return logs
distributed_outputs = self.strategy.run(step_fn, args=(next(iterator),))
return tf.nest.map_structure(self.strategy.experimental_local_results,
distributed_outputs)
def eval_end(self, aggregated_logs=None):
"""Processes evaluation results."""
self.join()
logs = {}
for metric in self.validation_metrics:
logs[metric.name] = metric.result()
if self.validation_loss.count.numpy() != 0:
logs[self.validation_loss.name] = self.validation_loss.result()
else:
# `self.validation_loss` metric was not updated, because the validation
# loss was not returned from the task's `validation_step` method.
logging.info("The task did not report validation loss.")
if aggregated_logs:
metrics = self.task.reduce_aggregated_logs(
aggregated_logs, global_step=self.global_step)
logs.update(metrics)
if self._checkpoint_exporter:
self._checkpoint_exporter.maybe_export_checkpoint(
self.checkpoint, logs, self.global_step.numpy())
metric_name = self.config.trainer.best_checkpoint_eval_metric
logs["best_" +
metric_name] = self._checkpoint_exporter.best_ckpt_logs[metric_name]
# Swaps back weights after testing when EMA is used.
# This happens after best checkpoint export so that average weights used for
# eval are exported instead of regular weights.
if self.optimizer and isinstance(self.optimizer,
optimization.ExponentialMovingAverage):
self.optimizer.swap_weights()
return logs
def eval_reduce(self, state=None, step_outputs=None):
return self.task.aggregate_logs(state, step_outputs)
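A minimal sketch of the Recovery helper defined above; the checkpoint directory and thresholds are illustrative:
```python
import tensorflow as tf

from official.core.base_trainer import Recovery

# Hypothetical checkpoint setup purely for illustration.
step = tf.Variable(0, dtype=tf.int64)
checkpoint = tf.train.Checkpoint(step=step)
manager = tf.train.CheckpointManager(
    checkpoint, directory="/tmp/recovery_sketch", max_to_keep=1)
manager.save()

recovery = Recovery(loss_upper_bound=10.0, checkpoint_manager=manager,
                    recovery_begin_steps=0, recovery_max_trials=2)

recovery.maybe_recover(loss_value=1.2, global_step=100)   # within bounds: no-op
recovery.maybe_recover(loss_value=50.0, global_step=101)  # blowup: restores the checkpoint
```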