Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
472e2f80
Commit
472e2f80
authored
Mar 16, 2024
by
zhanggzh
Browse files
Merge remote-tracking branch 'tf_model/main'
parents
d91296eb
f3a14f85
Changes
215
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3680 additions
and
0 deletions
+3680
-0
models-2.13.1/official/common/distribute_utils.py
models-2.13.1/official/common/distribute_utils.py
+233
-0
models-2.13.1/official/common/distribute_utils_test.py
models-2.13.1/official/common/distribute_utils_test.py
+124
-0
models-2.13.1/official/common/flags.py
models-2.13.1/official/common/flags.py
+114
-0
models-2.13.1/official/common/registry_imports.py
models-2.13.1/official/common/registry_imports.py
+20
-0
models-2.13.1/official/common/streamz_counters.py
models-2.13.1/official/common/streamz_counters.py
+27
-0
models-2.13.1/official/core/__init__.py
models-2.13.1/official/core/__init__.py
+31
-0
models-2.13.1/official/core/actions.py
models-2.13.1/official/core/actions.py
+236
-0
models-2.13.1/official/core/actions_test.py
models-2.13.1/official/core/actions_test.py
+131
-0
models-2.13.1/official/core/base_task.py
models-2.13.1/official/core/base_task.py
+358
-0
models-2.13.1/official/core/base_trainer.py
models-2.13.1/official/core/base_trainer.py
+477
-0
models-2.13.1/official/core/base_trainer_test.py
models-2.13.1/official/core/base_trainer_test.py
+363
-0
models-2.13.1/official/core/config_definitions.py
models-2.13.1/official/core/config_definitions.py
+306
-0
models-2.13.1/official/core/exp_factory.py
models-2.13.1/official/core/exp_factory.py
+32
-0
models-2.13.1/official/core/export_base.py
models-2.13.1/official/core/export_base.py
+182
-0
models-2.13.1/official/core/export_base_test.py
models-2.13.1/official/core/export_base_test.py
+133
-0
models-2.13.1/official/core/file_writers.py
models-2.13.1/official/core/file_writers.py
+80
-0
models-2.13.1/official/core/file_writers_test.py
models-2.13.1/official/core/file_writers_test.py
+53
-0
models-2.13.1/official/core/input_reader.py
models-2.13.1/official/core/input_reader.py
+591
-0
models-2.13.1/official/core/registry.py
models-2.13.1/official/core/registry.py
+101
-0
models-2.13.1/official/core/registry_test.py
models-2.13.1/official/core/registry_test.py
+88
-0
No files found.
Too many changes to show.
To preserve performance only
215 of 215+
files are displayed.
Plain diff
Email patch
models-2.13.1/official/common/distribute_utils.py
0 → 100644
View file @
472e2f80
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions for running models in a distributed setting."""
import
json
import
os
import
tensorflow
as
tf
def _collective_communication(all_reduce_alg):
  """Map `all_reduce_alg` to a CollectiveCommunication enum value.

  Args:
    all_reduce_alg: a string specifying which collective communication to pick,
      or None.

  Returns:
    tf.distribute.experimental.CollectiveCommunication object

  Raises:
    ValueError: if `all_reduce_alg` not in [None, "ring", "nccl"]
  """
  # None deliberately maps to AUTO so callers may omit the algorithm entirely.
  comm = tf.distribute.experimental.CollectiveCommunication
  valid_options = {
      None: comm.AUTO,
      "ring": comm.RING,
      "nccl": comm.NCCL,
  }
  if all_reduce_alg not in valid_options:
    raise ValueError(
        "When used with `multi_worker_mirrored`, valid values for "
        "all_reduce_alg are [`ring`, `nccl`]. Supplied value: {}".format(
            all_reduce_alg))
  return valid_options[all_reduce_alg]
def
_mirrored_cross_device_ops
(
all_reduce_alg
,
num_packs
):
"""Return a CrossDeviceOps based on all_reduce_alg and num_packs.
Args:
all_reduce_alg: a string specifying which cross device op to pick, or None.
num_packs: an integer specifying number of packs for the cross device op.
Returns:
tf.distribute.CrossDeviceOps object or None.
Raises:
ValueError: if `all_reduce_alg` not in [None, "nccl", "hierarchical_copy"].
"""
if
all_reduce_alg
is
None
:
return
None
mirrored_all_reduce_options
=
{
"nccl"
:
tf
.
distribute
.
NcclAllReduce
,
"hierarchical_copy"
:
tf
.
distribute
.
HierarchicalCopyAllReduce
}
if
all_reduce_alg
not
in
mirrored_all_reduce_options
:
raise
ValueError
(
"When used with `mirrored`, valid values for all_reduce_alg are "
"[`nccl`, `hierarchical_copy`]. Supplied value: {}"
.
format
(
all_reduce_alg
))
cross_device_ops_class
=
mirrored_all_reduce_options
[
all_reduce_alg
]
return
cross_device_ops_class
(
num_packs
=
num_packs
)
def tpu_initialize(tpu_address):
  """Initializes TPU for TF 2.x training.

  Args:
    tpu_address: string, bns address of master TPU worker.

  Returns:
    A TPUClusterResolver.
  """
  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
      tpu=tpu_address)
  # "" and "local" denote on-host TPUs; only remote workers need an explicit
  # cluster connection.
  if tpu_address not in ("", "local"):
    tf.config.experimental_connect_to_cluster(resolver)
  tf.tpu.experimental.initialize_tpu_system(resolver)
  return resolver
def get_distribution_strategy(distribution_strategy="mirrored",
                              num_gpus=0,
                              all_reduce_alg=None,
                              num_packs=1,
                              tpu_address=None,
                              **kwargs):
  """Return a Strategy for running the model.

  Args:
    distribution_strategy: a string specifying which distribution strategy to
      use. Accepted values are "off", "one_device", "mirrored",
      "parameter_server", "multi_worker_mirrored", and "tpu" -- case
      insensitive. "tpu" means to use TPUStrategy using `tpu_address`. "off"
      means to use the default strategy which is obtained from
      tf.distribute.get_strategy (for details on the default strategy, see
      https://www.tensorflow.org/guide/distributed_training#default_strategy).
    num_gpus: Number of GPUs to run this model.
    all_reduce_alg: Optional. Specifies which algorithm to use when performing
      all-reduce. For `MirroredStrategy`, valid values are "nccl" and
      "hierarchical_copy". For `MultiWorkerMirroredStrategy`, valid values are
      "ring" and "nccl". If None, DistributionStrategy will choose based on
      device topology.
    num_packs: Optional. Sets the `num_packs` in `tf.distribute.NcclAllReduce`
      or `tf.distribute.HierarchicalCopyAllReduce` for `MirroredStrategy`.
    tpu_address: Optional. String that represents TPU to connect to. Must not
      be None if `distribution_strategy` is set to `tpu`.
    **kwargs: Additional kwargs for internal usages.

  Returns:
    tf.distribute.Strategy object.

  Raises:
    ValueError: if `distribution_strategy` is "off" or "one_device" and
      `num_gpus` is larger than 1; or `num_gpus` is negative or if
      `distribution_strategy` is `tpu` but `tpu_address` is not specified.
  """
  del kwargs
  if num_gpus < 0:
    raise ValueError("`num_gpus` can not be negative.")

  # Guard against YAML parsing `off` as a boolean rather than a string.
  if not isinstance(distribution_strategy, str):
    msg = ("distribution_strategy must be a string but got: %s." %
           (distribution_strategy,))
    if distribution_strategy == False:  # pylint: disable=singleton-comparison,g-explicit-bool-comparison
      msg += (" If you meant to pass the string 'off', make sure you add "
              "quotes around 'off' so that yaml interprets it as a string "
              "instead of a bool.")
    raise ValueError(msg)

  strategy_name = distribution_strategy.lower()

  if strategy_name == "off":
    if num_gpus > 1:
      raise ValueError(f"When {num_gpus} GPUs are specified, "
                       "distribution_strategy flag cannot be set to `off`.")
    # Return the default distribution strategy.
    return tf.distribute.get_strategy()

  if strategy_name == "tpu":
    # When tpu_address is an empty string, we communicate with local TPUs.
    return tf.distribute.TPUStrategy(tpu_initialize(tpu_address))

  if strategy_name == "multi_worker_mirrored":
    return tf.distribute.experimental.MultiWorkerMirroredStrategy(
        communication=_collective_communication(all_reduce_alg))

  if strategy_name == "one_device":
    if num_gpus == 0:
      return tf.distribute.OneDeviceStrategy("device:CPU:0")
    if num_gpus > 1:
      raise ValueError("`OneDeviceStrategy` can not be used for more than "
                       "one device.")
    return tf.distribute.OneDeviceStrategy("device:GPU:0")

  if strategy_name == "mirrored":
    if num_gpus == 0:
      device_list = ["device:CPU:0"]
    else:
      device_list = ["device:GPU:%d" % i for i in range(num_gpus)]
    return tf.distribute.MirroredStrategy(
        devices=device_list,
        cross_device_ops=_mirrored_cross_device_ops(all_reduce_alg, num_packs))

  if strategy_name == "parameter_server":
    resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
    return tf.distribute.experimental.ParameterServerStrategy(resolver)

  raise ValueError("Unrecognized Distribution Strategy: %r" % strategy_name)
def configure_cluster(worker_hosts=None, task_index=-1):
  """Set multi-worker cluster spec in TF_CONFIG environment variable.

  Args:
    worker_hosts: comma-separated list of worker ip:port pairs.
    task_index: index of the worker.

  Returns:
    Number of workers in the cluster.
  """
  tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
  if tf_config:
    # An externally-provided TF_CONFIG wins; just report its worker count.
    cluster = tf_config["cluster"]
    return len(cluster.get("chief", [])) + len(cluster.get("worker", []))

  if not worker_hosts:
    # Neither TF_CONFIG nor worker_hosts given: single-worker setup.
    return 1

  workers = worker_hosts.split(",")
  num_workers = len(workers)
  if num_workers > 1 and task_index < 0:
    raise ValueError("Must specify task_index when number of workers > 1")
  # A lone worker is always index 0 regardless of the supplied task_index.
  task_index = 0 if num_workers == 1 else task_index
  os.environ["TF_CONFIG"] = json.dumps({
      "cluster": {"worker": workers},
      "task": {"type": "worker", "index": task_index},
  })
  return num_workers
def get_strategy_scope(strategy):
  """Return `strategy.scope()`, or a no-op context manager when falsy."""
  if strategy:
    return strategy.scope()
  return DummyContextManager()
class DummyContextManager(object):
  """No-op context manager used when no distribution strategy is supplied."""

  def __enter__(self):
    return None

  def __exit__(self, *args):
    return None
models-2.13.1/official/common/distribute_utils_test.py
0 → 100644
View file @
472e2f80
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for distribution util functions."""
import
sys
import
tensorflow
as
tf
from
official.common
import
distribute_utils
# True when this binary was invoked as a TPU test (binary name contains
# 'test_tpu').
TPU_TEST = sys.argv[0].find('test_tpu') >= 0
class DistributeUtilsTest(tf.test.TestCase):
  """Tests for distribute util functions.

  Note: uses `assertEqual`/`assertRaisesRegex` rather than the deprecated
  `assertEquals`/`assertRaisesRegexp` aliases, which were removed in
  Python 3.12.
  """

  def test_invalid_args(self):
    with self.assertRaisesRegex(ValueError, '`num_gpus` can not be negative.'):
      _ = distribute_utils.get_distribution_strategy(num_gpus=-1)

    with self.assertRaisesRegex(ValueError,
                                '.*If you meant to pass the string .*'):
      _ = distribute_utils.get_distribution_strategy(
          distribution_strategy=False, num_gpus=0)
    with self.assertRaisesRegex(ValueError, 'When 2 GPUs are specified.*'):
      _ = distribute_utils.get_distribution_strategy(
          distribution_strategy='off', num_gpus=2)
    with self.assertRaisesRegex(ValueError,
                                '`OneDeviceStrategy` can not be used.*'):
      _ = distribute_utils.get_distribution_strategy(
          distribution_strategy='one_device', num_gpus=2)

  def test_one_device_strategy_cpu(self):
    ds = distribute_utils.get_distribution_strategy('one_device', num_gpus=0)
    self.assertEqual(ds.num_replicas_in_sync, 1)
    self.assertEqual(len(ds.extended.worker_devices), 1)
    self.assertIn('CPU', ds.extended.worker_devices[0])

  def test_one_device_strategy_gpu(self):
    ds = distribute_utils.get_distribution_strategy('one_device', num_gpus=1)
    self.assertEqual(ds.num_replicas_in_sync, 1)
    self.assertEqual(len(ds.extended.worker_devices), 1)
    self.assertIn('GPU', ds.extended.worker_devices[0])

  def test_mirrored_strategy(self):
    # CPU only.
    _ = distribute_utils.get_distribution_strategy(num_gpus=0)
    # 5 GPUs.
    ds = distribute_utils.get_distribution_strategy(num_gpus=5)
    self.assertEqual(ds.num_replicas_in_sync, 5)
    self.assertEqual(len(ds.extended.worker_devices), 5)
    for device in ds.extended.worker_devices:
      self.assertIn('GPU', device)

    _ = distribute_utils.get_distribution_strategy(
        distribution_strategy='mirrored',
        num_gpus=2,
        all_reduce_alg='nccl',
        num_packs=2)
    with self.assertRaisesRegex(
        ValueError,
        'When used with `mirrored`, valid values for all_reduce_alg are.*'):
      _ = distribute_utils.get_distribution_strategy(
          distribution_strategy='mirrored',
          num_gpus=2,
          all_reduce_alg='dummy',
          num_packs=2)

  def test_mwms(self):
    distribute_utils.configure_cluster(worker_hosts=None, task_index=-1)
    ds = distribute_utils.get_distribution_strategy(
        'multi_worker_mirrored', all_reduce_alg='nccl')
    self.assertIsInstance(
        ds, tf.distribute.experimental.MultiWorkerMirroredStrategy)

    with self.assertRaisesRegex(
        ValueError, 'When used with `multi_worker_mirrored`, valid values.*'):
      _ = distribute_utils.get_distribution_strategy(
          'multi_worker_mirrored', all_reduce_alg='dummy')

  def test_no_strategy(self):
    ds = distribute_utils.get_distribution_strategy('off')
    self.assertIs(ds, tf.distribute.get_strategy())

  def test_tpu_strategy(self):
    if not TPU_TEST:
      self.skipTest('Only Cloud TPU VM instances can have local TPUs.')
    with self.assertRaises(ValueError):
      _ = distribute_utils.get_distribution_strategy('tpu')

    ds = distribute_utils.get_distribution_strategy('tpu', tpu_address='local')
    self.assertIsInstance(ds, tf.distribute.TPUStrategy)

  def test_invalid_strategy(self):
    with self.assertRaisesRegex(
        ValueError,
        'distribution_strategy must be a string but got: False. If'):
      distribute_utils.get_distribution_strategy(False)
    with self.assertRaisesRegex(
        ValueError, 'distribution_strategy must be a string but got: 1'):
      distribute_utils.get_distribution_strategy(1)

  def test_get_strategy_scope(self):
    ds = distribute_utils.get_distribution_strategy('one_device', num_gpus=0)
    with distribute_utils.get_strategy_scope(ds):
      self.assertIs(tf.distribute.get_strategy(), ds)
    with distribute_utils.get_strategy_scope(None):
      self.assertIsNot(tf.distribute.get_strategy(), ds)
# Run the test suite only when executed directly, not on import.
if __name__ == '__main__':
  tf.test.main()
models-2.13.1/official/common/flags.py
0 → 100644
View file @
472e2f80
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The central place to define flags."""
from
absl
import
flags
def define_flags():
  """Defines flags.

  All flags are defined as optional, but in practice most models use some of
  these flags and so mark_flags_as_required() should be called after calling
  this function. Typically, 'experiment', 'mode', and 'model_dir' are required.
  For example:

  ```
  from absl import flags
  from official.common import flags as tfm_flags  # pylint: disable=line-too-long
  ...
  tfm_flags.define_flags()
  flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
  ```

  The reason all flags are optional is because unit tests often do not set or
  use any of the flags.
  """
  flags.DEFINE_string(
      'experiment',
      default=None,
      help='The experiment type registered, specifying an ExperimentConfig.')

  flags.DEFINE_enum(
      'mode',
      default=None,
      enum_values=[
          'train', 'eval', 'train_and_eval', 'continuous_eval',
          'continuous_train_and_eval', 'train_and_validate',
          'train_and_post_eval'
      ],
      help='Mode to run: `train`, `eval`, `train_and_eval`, '
      '`continuous_eval`, `continuous_train_and_eval` and '
      '`train_and_validate` (which is not implemented in '
      'the open source version).')

  flags.DEFINE_string(
      'model_dir',
      default=None,
      # Fixed: the two literals previously joined as "summariesare stored."
      # (missing separator space between the adjacent string fragments).
      help='The directory where the model and training/evaluation summaries '
      'are stored.')

  flags.DEFINE_multi_string(
      'config_file',
      default=None,
      help='YAML/JSON files which specifies overrides. The override order '
      'follows the order of args. Note that each file '
      'can be used as an override template to override the default parameters '
      'specified in Python. If the same parameter is specified in both '
      '`--config_file` and `--params_override`, `config_file` will be used '
      'first, followed by params_override.')

  flags.DEFINE_string(
      'params_override',
      default=None,
      help='a YAML/JSON string or a YAML file which specifies additional '
      'overrides over the default parameters and those specified in '
      '`--config_file`. Note that this is supposed to be used only to override '
      'the model parameters, but not the parameters like TPU specific flags. '
      'One canonical use case of `--config_file` and `--params_override` is '
      'users first define a template config file using `--config_file`, then '
      'use `--params_override` to adjust the minimal set of tuning parameters, '
      'for example setting up different `train_batch_size`. The final override '
      'order of parameters: default_model_params --> params from config_file '
      '--> params in params_override. See also the help message of '
      '`--config_file`.')

  # The libraries rely on gin often make mistakes that include flags inside
  # the library files which causes conflicts.
  try:
    flags.DEFINE_multi_string(
        'gin_file', default=None, help='List of paths to the config files.')
  except flags.DuplicateFlagError:
    pass

  try:
    flags.DEFINE_multi_string(
        'gin_params',
        default=None,
        help='Newline separated list of Gin parameter bindings.')
  except flags.DuplicateFlagError:
    pass

  flags.DEFINE_string(
      'tpu',
      default=None,
      help='The Cloud TPU to use for training. This should be either the name '
      'used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 '
      'url.')

  flags.DEFINE_string(
      'tf_data_service', default=None, help='The tf.data service address')

  flags.DEFINE_string(
      'tpu_platform', default=None, help='TPU platform type.')
models-2.13.1/official/common/registry_imports.py
0 → 100644
View file @
472e2f80
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""All necessary imports for registration."""
# pylint: disable=unused-import
from
official
import
vision
from
official.nlp
import
tasks
from
official.nlp.configs
import
experiment_configs
from
official.utils.testing
import
mock_task
models-2.13.1/official/common/streamz_counters.py
0 → 100644
View file @
472e2f80
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Global streamz counters."""
from
tensorflow.python.eager
import
monitoring
# Incremented once per ProgressivePolicy construction.
progressive_policy_creation_counter = monitoring.Counter(
    "/tensorflow/training/fast_training/progressive_policy_creation",
    "Counter for the number of ProgressivePolicy creations.")

# Incremented once per low-level stacking (vars-to-vars) API call.
stack_vars_to_vars_call_counter = monitoring.Counter(
    "/tensorflow/training/fast_training/tf_vars_to_vars",
    "Counter for the number of low-level stacking API calls.")
models-2.13.1/official/core/__init__.py
0 → 100644
View file @
472e2f80
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Core is shared by both `nlp` and `vision`."""
from
official.core
import
actions
from
official.core
import
base_task
from
official.core
import
base_trainer
from
official.core
import
config_definitions
from
official.core
import
exp_factory
from
official.core
import
export_base
from
official.core
import
file_writers
from
official.core
import
input_reader
from
official.core
import
registry
from
official.core
import
savedmodel_checkpoint_manager
from
official.core
import
task_factory
from
official.core
import
tf_example_builder
from
official.core
import
tf_example_feature_key
from
official.core
import
train_lib
from
official.core
import
train_utils
models-2.13.1/official/core/actions.py
0 → 100644
View file @
472e2f80
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides TFM orbit actions and associated helper functions/classes."""
import
os
from
typing
import
List
from
absl
import
logging
import
gin
import
orbit
import
tensorflow
as
tf
from
official.core
import
base_trainer
from
official.core
import
config_definitions
from
official.modeling
import
optimization
class PruningAction:
  """Train action that updates pruning-related information.

  This action updates pruning steps at the end of the training loop, and logs
  pruning metrics to tensorboard.

  This action must be used when training a pruned model to avoid pruning
  errors.
  """

  def __init__(
      self,
      export_dir: str,
      model: tf.keras.Model,
      optimizer: tf.keras.optimizers.Optimizer,
  ):
    """Initializes the instance.

    Args:
      export_dir: `str` for the export directory of the pruning summaries.
      model: `tf.keras.Model` model instance used for training. This will be
        used to assign a pruning step to each prunable weight.
      optimizer: `tf.keras.optimizers.Optimizer` optimizer instance used for
        training. This will be used to find the current training steps.
    """
    # TODO(b/221490190): Avoid local import when the bug is fixed.
    import tensorflow_model_optimization as tfmot  # pylint: disable=g-import-not-at-top

    self._optimizer = optimizer

    # Wire up the tfmot Keras callbacks manually since orbit does not run the
    # usual Keras callback machinery.
    self.update_pruning_step = tfmot.sparsity.keras.UpdatePruningStep()
    self.update_pruning_step.set_model(model)
    self.update_pruning_step.on_train_begin()

    self.pruning_summaries = tfmot.sparsity.keras.PruningSummaries(
        log_dir=export_dir)
    model.optimizer = optimizer
    self.pruning_summaries.set_model(model)

  def __call__(self, output: orbit.runner.Output):
    """Updates the pruning step and logs pruning summaries.

    Args:
      output: The train output.
    """
    self.update_pruning_step.on_epoch_end(batch=None)
    self.pruning_summaries.on_epoch_begin(epoch=None)
class EMACheckpointing:
  """Eval action to save checkpoint with average weights when EMA is used.

  This action swaps the weights of the model with the average weights, then it
  saves the checkpoint under export_dir/ema_checkpoints. Checkpointing is
  expensive for large models, so doing this action in eval is more efficient
  than training.
  """

  def __init__(self,
               export_dir: str,
               optimizer: tf.keras.optimizers.Optimizer,
               checkpoint: tf.train.Checkpoint,
               max_to_keep: int = 1):
    """Initializes the instance.

    Args:
      export_dir: `str` for the export directory of the EMA average weights.
      optimizer: `tf.keras.optimizers.Optimizer` optimizer instance used for
        training. This will be used to swap the model weights with the average
        weights.
      checkpoint: `tf.train.Checkpoint` instance.
      max_to_keep: `int` for max checkpoints to keep in ema_checkpoints subdir.

    Raises:
      ValueError: if `optimizer` is not an
        `optimization.ExponentialMovingAverage` instance.
    """
    if not isinstance(optimizer, optimization.ExponentialMovingAverage):
      # Fixed: the message fragments previously joined without separating
      # spaces ("instance ofoptimization...forEMACheckpointing action").
      raise ValueError('Optimizer has to be instance of '
                       'optimization.ExponentialMovingAverage for '
                       'EMACheckpointing action')

    export_dir = os.path.join(export_dir, 'ema_checkpoints')
    tf.io.gfile.makedirs(os.path.dirname(export_dir))
    self._optimizer = optimizer
    self._checkpoint = checkpoint
    self._checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        directory=export_dir,
        max_to_keep=max_to_keep,
        checkpoint_name='average_weights')

  def __call__(self, output: orbit.runner.Output):
    """Swaps model weights, and saves the checkpoint.

    Args:
      output: The train or eval output.
    """
    # Swap in the EMA weights, checkpoint them, then swap the live training
    # weights back so training is unaffected.
    self._optimizer.swap_weights()
    self._checkpoint_manager.save(checkpoint_number=self._optimizer.iterations)
    self._optimizer.swap_weights()
class RecoveryAction:
  """Train action to recover from loss blowup.

  Checks the loss value by the given threshold. If applicable, recover the
  model by reading the checkpoint on disk.
  """

  def __init__(self, checkpoint_manager: tf.train.CheckpointManager):
    self.checkpoint_manager = checkpoint_manager

  def __call__(self, _):
    """Recovers the training by triggering checkpoint restoration."""
    # Loads the previous good checkpoint.
    restored_path = self.checkpoint_manager.restore_or_initialize()
    logging.warning('Recovering the model from checkpoint: %s.', restored_path)
class RecoveryCondition:
  """Recovery Condition.

  Callable predicate used with `orbit.actions.ConditionalAction`: returns True
  when the training loss is NaN or exceeds `loss_upper_bound`, and raises
  once the number of recoveries passes `recovery_max_trials`.
  """

  def __init__(self,
               global_step: tf.Variable,
               loss_upper_bound: float,
               recovery_begin_steps: int = 0,
               recovery_max_trials: int = 3):
    self.recover_counter = 0
    self.recovery_begin_steps = recovery_begin_steps
    self.recovery_max_trials = recovery_max_trials
    self.loss_upper_bound = loss_upper_bound
    self.global_step = global_step

  def __call__(self, outputs: orbit.runner.Output):
    loss_value = outputs['training_loss']

    # NaN loss: always attempt recovery, up to the trial budget.
    if tf.math.is_nan(loss_value):
      self.recover_counter += 1
      if self.recover_counter > self.recovery_max_trials:
        raise RuntimeError(
            'The loss value is NaN after training loop and it happens %d times.'
            % self.recover_counter)
      return True

    # Loss blowup past the bound, once warm-up steps have elapsed.
    blown_up = (self.global_step >= self.recovery_begin_steps and
                loss_value > self.loss_upper_bound)
    if blown_up:
      self.recover_counter += 1
      if self.recover_counter > self.recovery_max_trials:
        raise RuntimeError(
            f'The loss value is {loss_value}, which is larger than the bound '
            f'{self.loss_upper_bound}, happens {self.recover_counter} times.')
      return True
    return False
@gin.configurable
def get_eval_actions(params: config_definitions.ExperimentConfig,
                     trainer: base_trainer.Trainer,
                     model_dir: str) -> List[orbit.Action]:
  """Gets eval actions for TFM trainer."""
  # EMA checkpointing (saving the average weights under the ema_checkpoints
  # subdir) is the only eval action, and applies only to EMA optimizers.
  if not isinstance(trainer.optimizer, optimization.ExponentialMovingAverage):
    return []
  return [
      EMACheckpointing(
          export_dir=model_dir,
          optimizer=trainer.optimizer,
          checkpoint=trainer.checkpoint,
          max_to_keep=params.trainer.max_to_keep)
  ]
@gin.configurable
def get_train_actions(
    params: config_definitions.ExperimentConfig, trainer: base_trainer.Trainer,
    model_dir: str,
    checkpoint_manager: tf.train.CheckpointManager) -> List[orbit.Action]:
  """Gets train actions for TFM trainer."""
  actions_list = []

  # Pruning callbacks, when the task config declares a pruning section.
  if hasattr(params.task, 'pruning') and params.task.pruning:
    actions_list.append(
        PruningAction(
            export_dir=model_dir,
            model=trainer.model,
            optimizer=trainer.optimizer))

  trainer_cfg = params.trainer

  # Loss-blowup recovery: restore from the last good checkpoint when the
  # training loss goes NaN or exceeds the configured bound.
  if trainer_cfg.recovery_max_trials >= 0:
    condition = RecoveryCondition(
        global_step=trainer.global_step,
        loss_upper_bound=trainer_cfg.loss_upper_bound,
        recovery_begin_steps=trainer_cfg.recovery_begin_steps,
        recovery_max_trials=trainer_cfg.recovery_max_trials,
    )
    actions_list.append(
        orbit.actions.ConditionalAction(
            condition=condition,
            action=RecoveryAction(checkpoint_manager),
        ))

  # On-demand checkpoint on preemption, when the strategy exposes a
  # cluster resolver.
  if (trainer_cfg.preemption_on_demand_checkpoint and
      trainer.strategy.cluster_resolver):
    actions_list.append(
        orbit.actions.SaveCheckpointIfPreempted(
            trainer.strategy.cluster_resolver,
            checkpoint_manager,
            trainer.global_step,
            keep_running_after_save=True,
        ))
  return actions_list
models-2.13.1/official/core/actions_test.py
0 → 100644
View file @
472e2f80
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for TFM actions."""
import
os
from
absl.testing
import
parameterized
import
numpy
as
np
import
orbit
import
tensorflow
as
tf
from
tensorflow.python.distribute
import
combinations
from
tensorflow.python.distribute
import
strategy_combinations
from
official.core
import
actions
from
official.modeling
import
optimization
class TestModel(tf.keras.Model):
  """Minimal model fixture: adds a scalar variable to its input."""

  def __init__(self):
    super().__init__()
    self.value = tf.Variable(0.0)
    self.dense = tf.keras.layers.Dense(2)
    # Build the dense layer eagerly so its weights exist at construction time.
    _ = self.dense(tf.zeros((2, 2), tf.float32))

  def call(self, x, training=None):
    return self.value + x
class ActionsTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for the train/eval actions in official.core.actions."""

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.cloud_tpu_strategy,
              strategy_combinations.one_device_strategy,
          ],))
  def test_ema_checkpointing(self, distribution):
    with distribution.scope():
      directory = self.create_tempdir()
      model = TestModel()
      optimizer = tf.keras.optimizers.SGD()
      optimizer = optimization.ExponentialMovingAverage(
          optimizer, trainable_weights_only=False)
      # Creates average weights for the model variables. Average weights are
      # initialized to zero.
      optimizer.shadow_copy(model)
      checkpoint = tf.train.Checkpoint(model=model)

      # Changes model.value to 3, average value is still 0.
      model.value.assign(3)
      # Checks model.value is 3
      self.assertEqual(model(0.), 3)

      ema_action = actions.EMACheckpointing(directory, optimizer, checkpoint)

      # Invoking the action should save a checkpoint holding the EMA weights.
      ema_action({})
      self.assertNotEmpty(
          tf.io.gfile.glob(os.path.join(directory, 'ema_checkpoints')))

      checkpoint.read(
          tf.train.latest_checkpoint(
              os.path.join(directory, 'ema_checkpoints')))

      # Checks model.value is 0 after swapping.
      self.assertEqual(model(0.), 0)

      # Raises an error for a normal optimizer.
      with self.assertRaisesRegex(ValueError,
                                  'Optimizer has to be instance of.*'):
        _ = actions.EMACheckpointing(directory, tf.keras.optimizers.SGD(),
                                     checkpoint)

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.default_strategy,
              strategy_combinations.cloud_tpu_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ],))
  def test_recovery_condition(self, distribution):
    with distribution.scope():
      # Case 1: loss exceeds the upper bound; the condition triggers until
      # recovery_max_trials is exhausted, then raises RuntimeError.
      global_step = orbit.utils.create_global_step()
      recover_condition = actions.RecoveryCondition(
          global_step, loss_upper_bound=0.5, recovery_max_trials=2)
      outputs = {'training_loss': 0.6}
      self.assertTrue(recover_condition(outputs))
      self.assertTrue(recover_condition(outputs))
      with self.assertRaises(RuntimeError):
        recover_condition(outputs)

      # Case 2: NaN loss behaves the same way as an out-of-bound loss.
      global_step = orbit.utils.create_global_step()
      recover_condition = actions.RecoveryCondition(
          global_step, loss_upper_bound=0.5, recovery_max_trials=2)
      outputs = {'training_loss': tf.constant([np.nan], tf.float32)}
      self.assertTrue(recover_condition(outputs))
      self.assertTrue(recover_condition(outputs))
      with self.assertRaises(RuntimeError):
        recover_condition(outputs)

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.one_device_strategy_gpu,
              strategy_combinations.one_device_strategy,
          ],))
  def test_pruning(self, distribution):
    # Smoke test: constructing and invoking the pruning action must not raise.
    with distribution.scope():
      directory = self.get_temp_dir()
      model = TestModel()
      optimizer = tf.keras.optimizers.SGD()
      pruning = actions.PruningAction(directory, model, optimizer)

      pruning({})
# Standard TF test entry point; discovers and runs the test cases above.
if __name__ == '__main__':
  tf.test.main()
models-2.13.1/official/core/base_task.py
0 → 100644
View file @
472e2f80
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines the base task abstraction."""
import
abc
import
functools
from
typing
import
Optional
from
absl
import
logging
import
tensorflow
as
tf
from
official.core
import
config_definitions
from
official.modeling
import
optimization
from
official.modeling
import
performance
from
official.modeling.privacy
import
configs
from
official.modeling.privacy
import
ops
# Short aliases for the config classes used in `Task.create_optimizer`'s
# signature below.
OptimizationConfig = optimization.OptimizationConfig
RuntimeConfig = config_definitions.RuntimeConfig
DifferentialPrivacyConfig = configs.DifferentialPrivacyConfig
class Task(tf.Module, metaclass=abc.ABCMeta):
  """A single-replica view of training procedure.

  Tasks provide artifacts for training/validation procedures, including
  loading/iterating over Datasets, training/validation steps, calculating the
  loss and customized metrics with reduction.
  """

  # Special keys in train/validate step returned logs.
  loss = "loss"

  def __init__(self,
               params,
               logging_dir: Optional[str] = None,
               name: Optional[str] = None):
    """Task initialization.

    Args:
      params: the task configuration instance, which can be any of dataclass,
        ConfigDict, namedtuple, etc.
      logging_dir: a string pointing to where the model, summaries etc. will be
        saved. You can also write additional stuff in this directory.
      name: the task name.
    """
    super().__init__(name=name)
    self._task_config = params
    self._logging_dir = logging_dir

  @property
  def task_config(self):
    # The raw task configuration object passed at construction time.
    return self._task_config

  @property
  def logging_dir(self) -> str:
    # Directory for model artifacts/summaries; may be None if not provided.
    return self._logging_dir

  @classmethod
  def create_optimizer(cls,
                       optimizer_config: OptimizationConfig,
                       runtime_config: Optional[RuntimeConfig] = None,
                       dp_config: Optional[DifferentialPrivacyConfig] = None):
    """Creates an TF optimizer from configurations.

    Args:
      optimizer_config: the parameters of the Optimization settings.
      runtime_config: the parameters of the runtime.
      dp_config: the parameter of differential privacy.

    Returns:
      A tf.optimizers.Optimizer object.
    """
    gradient_transformers = None
    if dp_config is not None:
      logging.info("Adding differential privacy transform with config %s.",
                   dp_config.as_dict())
      noise_stddev = dp_config.clipping_norm * dp_config.noise_multiplier
      # DP-SGD style transforms: clip per-batch gradients, then add noise.
      gradient_transformers = [
          functools.partial(
              ops.clip_l2_norm, l2_norm_clip=dp_config.clipping_norm),
          functools.partial(ops.add_noise, noise_stddev=noise_stddev)
      ]
    opt_factory = optimization.OptimizerFactory(optimizer_config)
    optimizer = opt_factory.build_optimizer(
        opt_factory.build_learning_rate(),
        gradient_transformers=gradient_transformers)
    # Configuring optimizer when loss_scale is set in runtime config. This helps
    # avoiding overflow/underflow for float16 computations.
    if runtime_config:
      optimizer = performance.configure_optimizer(
          optimizer,
          use_float16=runtime_config.mixed_precision_dtype == "float16",
          loss_scale=runtime_config.loss_scale)
    return optimizer

  def initialize(self, model: tf.keras.Model):
    """[Optional] A callback function used as CheckpointManager's init_fn.

    This function will be called when no checkpoint is found for the model.
    If there is a checkpoint, the checkpoint will be loaded and this function
    will not be called. You can use this callback function to load a pretrained
    checkpoint, saved under a directory other than the model_dir.

    Args:
      model: The keras.Model built or used by this task.
    """
    ckpt_dir_or_file = self.task_config.init_checkpoint
    logging.info("Trying to load pretrained checkpoint from %s",
                 ckpt_dir_or_file)
    # If a directory is given, resolve it to the latest checkpoint inside.
    if ckpt_dir_or_file and tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
    if not ckpt_dir_or_file:
      logging.info("No checkpoint file found from %s. Will not load.",
                   ckpt_dir_or_file)
      return

    # Models may expose a custom mapping of checkpointed objects via
    # `checkpoint_items`; fall back to checkpointing the model as a whole.
    if hasattr(model, "checkpoint_items"):
      checkpoint_items = model.checkpoint_items
    else:
      checkpoint_items = dict(model=model)
    ckpt = tf.train.Checkpoint(**checkpoint_items)
    status = ckpt.read(ckpt_dir_or_file)
    # Partial restores are allowed, but every object that exists in the model
    # must have been matched by the checkpoint.
    status.expect_partial().assert_existing_objects_matched()
    logging.info("Finished loading pretrained checkpoint from %s",
                 ckpt_dir_or_file)

  def build_model(self) -> tf.keras.Model:
    """[Optional] Creates model architecture.

    Returns:
      A model instance.
    """
    # Intentionally no body: subclasses override; the bare signature returns
    # None, which pytype would flag, hence the directive below.
    # pytype: disable=bad-return-type  # typed-keras

  @abc.abstractmethod
  def build_inputs(self,
                   params,
                   input_context: Optional[tf.distribute.InputContext] = None):
    """Returns a dataset or a nested structure of dataset functions.

    Dataset functions define per-host datasets with the per-replica batch size.
    With distributed training, this method runs on remote hosts.

    Args:
      params: hyperparams to create input pipelines, which can be any of
        dataclass, ConfigDict, namedtuple, etc.
      input_context: optional distribution input pipeline context.

    Returns:
      A nested structure of per-replica input functions.
    """

  def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
    """Standard interface to compute losses.

    Args:
      labels: optional label tensors.
      model_outputs: a nested structure of output tensors.
      aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.

    Returns:
      The total loss tensor.
    """
    # Default implementation only sums auxiliary losses; labels/outputs are
    # ignored. Subclasses implement the task-specific loss.
    del model_outputs, labels

    if aux_losses is None:
      losses = [tf.constant(0.0, dtype=tf.float32)]
    else:
      losses = aux_losses
    total_loss = tf.add_n(losses)
    return total_loss

  def build_metrics(self, training: bool = True):
    """Gets streaming metrics for training/validation."""
    del training
    return []

  def process_metrics(self, metrics, labels, model_outputs, **kwargs):
    """Process and update metrics.

    Called when using custom training loop API.

    Args:
      metrics: a nested structure of metrics objects. The return of function
        self.build_metrics.
      labels: a tensor or a nested structure of tensors.
      model_outputs: a tensor or a nested structure of tensors. For example,
        output of the keras model built by self.build_model.
      **kwargs: other args.
    """
    for metric in metrics:
      metric.update_state(labels, model_outputs)

  def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
    """Process and update compiled_metrics.

    call when using compile/fit API.

    Args:
      compiled_metrics: the compiled metrics (model.compiled_metrics).
      labels: a tensor or a nested structure of tensors.
      model_outputs: a tensor or a nested structure of tensors. For example,
        output of the keras model built by self.build_model.
    """
    compiled_metrics.update_state(labels, model_outputs)

  def train_step(self,
                 inputs,
                 model: tf.keras.Model,
                 optimizer: tf.keras.optimizers.Optimizer,
                 metrics=None):
    """Does forward and backward.

    With distribution strategies, this method runs on devices.

    Args:
      inputs: a dictionary of input tensors.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    # A (features, labels) tuple is unpacked; any other structure is passed to
    # both the model and the loss as-is.
    if isinstance(inputs, tuple) and len(inputs) == 2:
      features, labels = inputs
    else:
      features, labels = inputs, inputs
    with tf.GradientTape() as tape:
      outputs = model(features, training=True)
      # Computes per-replica loss.
      if model.compiled_loss:
        loss = model.compiled_loss(
            labels, outputs, regularization_losses=model.losses)
        loss += self.build_losses(
            labels=labels, model_outputs=outputs, aux_losses=None)
      else:
        loss = self.build_losses(
            labels=labels, model_outputs=outputs, aux_losses=model.losses)
      # Scales loss as the default gradients allreduce performs sum inside the
      # optimizer.
      scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync

      # For mixed precision, when a LossScaleOptimizer is used, the loss is
      # scaled to avoid numeric underflow.
      if isinstance(optimizer,
                    tf.keras.mixed_precision.LossScaleOptimizer):
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)

    if isinstance(optimizer,
                  tf.keras.mixed_precision.LossScaleOptimizer):
      grads = optimizer.get_unscaled_gradients(grads)
    optimizer.apply_gradients(list(zip(grads, tvars)))
    # Note: logs report the unscaled (per-replica) loss.
    logs = {self.loss: loss}
    if metrics:
      self.process_metrics(metrics, labels, outputs)
    if model.compiled_metrics:
      self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
      logs.update({m.name: m.result() for m in metrics or []})
      logs.update({m.name: m.result() for m in model.metrics})
    return logs

  def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
    """Validation step.

    With distribution strategies, this method runs on devices.

    Args:
      inputs: a dictionary of input tensors.
      model: the keras.Model.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    if isinstance(inputs, tuple) and len(inputs) == 2:
      features, labels = inputs
    else:
      features, labels = inputs, inputs
    outputs = self.inference_step(features, model)
    loss = self.build_losses(
        labels=labels, model_outputs=outputs, aux_losses=model.losses)
    logs = {self.loss: loss}
    if metrics:
      self.process_metrics(metrics, labels, outputs)
    if model.compiled_metrics:
      self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
      logs.update({m.name: m.result() for m in metrics or []})
      logs.update({m.name: m.result() for m in model.metrics})
    return logs

  def inference_step(self, inputs, model: tf.keras.Model):
    """Performs the forward step.

    With distribution strategies, this method runs on devices.

    Args:
      inputs: a dictionary of input tensors.
      model: the keras.Model.

    Returns:
      Model outputs.
    """
    return model(inputs, training=False)

  def aggregate_logs(self, state, step_logs):
    """Optional aggregation over logs returned from a validation step.

    Given step_logs from a validation step, this function aggregates the logs
    after each eval_step() (see eval_reduce() function in
    official/core/base_trainer.py). It runs on CPU and can be used to aggregate
    metrics during validation, when there are too many metrics that cannot fit
    into TPU memory. Note that this may increase latency due to data transfer
    between TPU and CPU. Also, the step output from a validation step may be a
    tuple with elements from replicas, and a concatenation of the elements is
    needed in such case.

    Args:
      state: The current state of training, for example, it can be a sequence of
        metrics.
      step_logs: Logs from a validation step. Can be a dictionary.
    """
    pass

  def reduce_aggregated_logs(self,
                             aggregated_logs,
                             global_step: Optional[tf.Tensor] = None):
    """Optional reduce of aggregated logs over validation steps.

    This function reduces aggregated logs at the end of validation, and can be
    used to compute the final metrics. It runs on CPU and in each eval_end() in
    base trainer (see eval_end() function in official/core/base_trainer.py).

    Args:
      aggregated_logs: Aggregated logs over multiple validation steps.
      global_step: An optional variable of global step.

    Returns:
      A dictionary of reduced results.
    """
    return {}
models-2.13.1/official/core/base_trainer.py
0 → 100644
View file @
472e2f80
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Standard Trainer implementation.
The base trainer implements the Orbit `StandardTrainable` and
`StandardEvaluable` interfaces. Trainers inside this project should be
interchangable and independent on model architectures and tasks.
"""
import
functools
from
typing
import
Union
,
Optional
from
absl
import
logging
import
gin
import
orbit
import
tensorflow
as
tf
from
official.core
import
base_task
from
official.core
import
config_definitions
from
official.modeling
import
optimization
# Short aliases for the config classes referenced throughout this module.
ExperimentConfig = config_definitions.ExperimentConfig
TrainerConfig = config_definitions.TrainerConfig
class _AsyncTrainer(orbit.StandardTrainer, orbit.StandardEvaluator):
  """Trainer class for both sync and async Strategy."""

  def init_async(self):
    """Initializes the Async Trainer base class."""
    assert isinstance(self._strategy, tf.distribute.Strategy)
    # Async mode is only used with ParameterServerStrategy; all other
    # strategies run the loops synchronously.
    self._is_async = isinstance(
        self._strategy, tf.distribute.experimental.ParameterServerStrategy)
    self._coordinator = None
    if self._is_async:
      self._coordinator = (
          tf.distribute.experimental.coordinator.ClusterCoordinator(
              self._strategy))

  def join(self):
    """Join all async steps. Only useful in async training."""
    if getattr(self, "_is_async", False):
      self._coordinator.join()

  def create_train_loop_fn(self):
    """Creates a training loop from the given step function and options."""
    train_loop_fn = super().create_train_loop_fn()
    if getattr(self, "_is_async", False):

      def _async_loop_fn(iterator, num_steps):
        # Schedules the loop on the coordinator instead of running it inline;
        # callers must `join()` to wait for completion.
        self._coordinator.schedule(train_loop_fn, args=(iterator, num_steps))

      return _async_loop_fn
    else:
      return train_loop_fn

  def create_eval_loop_fn(self, has_state: bool):
    """Creates an eval loop from the given step function and options."""
    eval_loop_fn = super().create_eval_loop_fn(has_state)
    if getattr(self, "_is_async", False):
      if has_state:
        raise ValueError(
            "Stateful eval loop is not supported in async training.")

      def _async_loop_fn(iterator, num_steps, state=None, reduce_fn=None):
        assert state is None
        assert reduce_fn is None
        self._coordinator.schedule(eval_loop_fn, args=(iterator, num_steps))

      return _async_loop_fn
    else:
      return eval_loop_fn

  def distribute_dataset(self, dataset_or_fn, *args, **kwargs):
    """A utility function to help create a `tf.distribute.DistributedDataset`.

    Args:
      dataset_or_fn: A instance of `tf.data.Dataset`, or a "dataset function"
        returning a `tf.data.Dataset`. If it is a function, it may optionally
        have an argument named `input_context` which will be passed a
        `tf.distribute.InputContext` instance.
      *args: Any positional arguments to pass through to `dataset_or_fn`.
      **kwargs: Any keyword arguments to pass through to `dataset_or_fn`.

    Returns:
      A distributed Dataset.
    """
    if getattr(self, "_is_async", False):
      # In async mode the dataset must be created per worker through the
      # coordinator, wrapped in a tf.function as the coordinator requires.
      per_worker_dataset_fn = functools.partial(
          orbit.utils.make_distributed_dataset, self._strategy, dataset_or_fn,
          *args, **kwargs)
      per_worker_dataset_fn = tf.function(per_worker_dataset_fn)

      return self._coordinator.create_per_worker_dataset(
          per_worker_dataset_fn)
    else:
      return orbit.utils.make_distributed_dataset(self._strategy,
                                                  dataset_or_fn, *args,
                                                  **kwargs)
def get_runtime_options(config: ExperimentConfig):
  """Builds `tf.distribute.RunOptions` from the experiment runtime config."""
  padder_flag = config.runtime.tpu_enable_xla_dynamic_padder
  # Only forward the flag when it is explicitly set; otherwise keep the
  # XLAOptions default.
  if padder_flag is None:
    xla_kwargs = {}
  else:
    xla_kwargs = {"enable_xla_dynamic_padder": padder_flag}
  xla_opts = tf.tpu.XLAOptions(**xla_kwargs)
  return tf.distribute.RunOptions(experimental_xla_options=xla_opts)
@gin.configurable
class Trainer(_AsyncTrainer):
  """Implements the common trainer shared for TensorFlow models."""

  # pylint: disable=super-init-not-called
  def __init__(
      self,
      config: ExperimentConfig,
      task: base_task.Task,
      model: tf.keras.Model,
      optimizer: tf.optimizers.Optimizer,
      train: bool = True,
      evaluate: bool = True,
      train_dataset: Optional[Union[tf.data.Dataset,
                                    tf.distribute.DistributedDataset]] = None,
      validation_dataset: Optional[Union[
          tf.data.Dataset, tf.distribute.DistributedDataset]] = None,
      checkpoint_exporter=None):
    """Initialize common trainer for TensorFlow models.

    Args:
      config: An `ExperimentConfig` instance specifying experiment config.
      task: A base_task.Task instance.
      model: The model instance, e.g. a tf.keras.Model instance.
      optimizer: tf.optimizers.Optimizer instance.
      train: bool, whether or not this trainer will be used for training.
        default to True.
      evaluate: bool, whether or not this trainer will be used for evaluation.
        default to True.
      train_dataset: a dataset object created for training. With tf.distribute,
        it needs to be a `DistributedDataset`.
      validation_dataset: a dataset object created for evaluation. With
        tf.distribute, it needs to be a `DistributedDataset`. The evaluator will
        create a dataset iterator for each eval round, so the dataset does not
        need to repeat.
      checkpoint_exporter: an object that has the `maybe_export_checkpoint`
        interface.
    """
    # Gets the current distribution strategy. If not inside any strategy scope,
    # it gets a single-replica no-op strategy.
    self._strategy = tf.distribute.get_strategy()
    self._validate_params(
        config,
        check_train_data=train_dataset is None,
        check_validation_data=validation_dataset is None)
    self._config = config
    self._task = task
    self._model = model
    self._optimizer = optimizer
    self._checkpoint_exporter = checkpoint_exporter
    self._recovery = None
    # Runtime options are only applied to train_step.
    # We use default for eval_step.
    self._runtime_options = get_runtime_options(config)

    # Creates a shadow copy of the weights to store weights moving average.
    if isinstance(self._optimizer, optimization.ExponentialMovingAverage
                 ) and not self._optimizer.has_shadow_copy:
      self._optimizer.shadow_copy(self._model)

    # global_step increases by 1 after each training iteration.
    # We should have global_step.numpy() == self.optimizer.iterations.numpy()
    # when there is only 1 optimizer.
    self._global_step = orbit.utils.create_global_step()
    if hasattr(self.model, "checkpoint_items"):
      checkpoint_items = self.model.checkpoint_items
    else:
      checkpoint_items = {}
    self._checkpoint = tf.train.Checkpoint(
        global_step=self.global_step,
        model=self.model,
        optimizer=self.optimizer,
        **checkpoint_items)

    self._train_loss = tf.keras.metrics.Mean(
        "training_loss", dtype=tf.float32)
    self._validation_loss = tf.keras.metrics.Mean(
        "validation_loss", dtype=tf.float32)
    model_metrics = model.metrics if hasattr(model, "metrics") else []

    self.init_async()

    if train:
      self._train_metrics = self.task.build_metrics(
          training=True) + model_metrics
      train_dataset = train_dataset or self.distribute_dataset(
          self.task.build_inputs, self.config.task.train_data)
      orbit.StandardTrainer.__init__(
          self,
          train_dataset,
          options=orbit.StandardTrainerOptions(
              use_tf_while_loop=config.trainer.train_tf_while_loop,
              use_tf_function=config.trainer.train_tf_function,
              use_tpu_summary_optimization=config.trainer.allow_tpu_summary))

    if evaluate:
      self._validation_metrics = self.task.build_metrics(
          training=False) + model_metrics
      validation_dataset = validation_dataset or self.distribute_dataset(
          self.task.build_inputs, self.config.task.validation_data)
      orbit.StandardEvaluator.__init__(
          self,
          validation_dataset,
          options=orbit.StandardEvaluatorOptions(
              use_tf_function=config.trainer.eval_tf_function,
              use_tf_while_loop=config.trainer.eval_tf_while_loop))

  def _validate_params(self,
                       config,
                       check_train_data=True,
                       check_validation_data=True):
    r"""Validates the configuration object passed to the Trainer.

    The experiment configuration should be structured as:
    \trainer
    \task
      \train_data
      \validation_data

    Args:
      config: a namedtuple, dataclass, ConfigDict, etc.
      check_train_data: whether to check task.train_data field.
      check_validation_data: whether to check task.validation_data field.
    """
    if not hasattr(config, "trainer"):
      raise AttributeError("The trainer requires the configuration contains an"
                           " attribute `trainer`.")

    if not hasattr(config, "task"):
      raise AttributeError("The trainer requires the configuration contains an"
                           " attribute `task`.")

    if check_train_data and not hasattr(config.task, "train_data"):
      raise AttributeError("The trainer requires the configuration contains an"
                           " attribute `task.train_data`.")

    if check_validation_data and not hasattr(config.task, "validation_data"):
      raise AttributeError("The trainer requires the configuration contains an"
                           " attribute `task.validation_data`.")

  @property
  def strategy(self):
    return self._strategy

  @property
  def config(self):
    return self._config

  @property
  def task(self):
    return self._task

  @property
  def model(self):
    return self._model

  @property
  def optimizer(self):
    if hasattr(self, "_optimizer"):
      return self._optimizer
    else:
      return None

  @property
  def global_step(self):
    return self._global_step

  @property
  def train_loss(self):
    """Accesses the training loss metric object."""
    return self._train_loss

  @property
  def validation_loss(self):
    """Accesses the validation loss metric object."""
    return self._validation_loss

  @property
  def train_metrics(self):
    """Accesses all training metric objects."""
    return self._train_metrics

  @property
  def validation_metrics(self):
    """Accesses all validation metric objects."""
    return self._validation_metrics

  def initialize(self):
    """A callback function.

    This function will be called when no checkpoint found for the model.
    If there is a checkpoint, the checkpoint will be loaded and this function
    will not be called. Tasks may use this callback function to load a
    pretrained checkpoint, saved under a directory other than the model_dir.
    """
    self.task.initialize(self.model)

  @property
  def checkpoint(self):
    """Accesses the training checkpoint."""
    return self._checkpoint

  @property
  def checkpoint_exporter(self):
    """Accesses the checkpoint exporter."""
    return self._checkpoint_exporter

  def train_loop_end(self):
    """See base class."""
    self.join()
    logs = {}
    for metric in self.train_metrics + [self.train_loss]:
      logs[metric.name] = metric.result()
      metric.reset_states()
    if callable(self.optimizer.learning_rate):
      # Maybe a self-implemented optimizer does not have `optimizer.iterations`.
      # So just to be safe here.
      if hasattr(self.optimizer, "iterations"):
        logs["learning_rate"] = self.optimizer.learning_rate(
            self.optimizer.iterations)
      else:
        logs["learning_rate"] = self.optimizer.learning_rate(self.global_step)
    else:
      logs["learning_rate"] = self.optimizer.learning_rate
    return logs

  def next_train_inputs(self, iterator):
    """Fetches the next inputs for the model during train.

    This method consumes the input iterator and returns the next inputs for the
    model.

    This method provides a way to control how to fetch the next model input, and
    what data to send to the model.

    This function runs in eager mode.

    Args:
      iterator: Dataset iterator to generate the next inputs from.

    Returns:
      The inputs to the model.
    """
    return next(iterator)

  def train_step(self, iterator):
    """See base class."""

    def step_fn(inputs):
      # Optionally JIT-compile the task train step when XLA is enabled on GPU.
      if self.config.runtime.enable_xla and (self.config.runtime.num_gpus > 0):
        task_train_step = tf.function(self.task.train_step, jit_compile=True)
      else:
        task_train_step = self.task.train_step
      logs = task_train_step(
          inputs,
          model=self.model,
          optimizer=self.optimizer,
          metrics=self.train_metrics)
      self._train_loss.update_state(logs[self.task.loss])
      self.global_step.assign_add(1)

    inputs = self.next_train_inputs(iterator)
    self.strategy.run(step_fn, args=(inputs,), options=self._runtime_options)

  def eval_begin(self):
    """Sets up metrics."""
    for metric in self.validation_metrics + [self.validation_loss]:
      metric.reset_states()
    # Swaps weights to test on weights moving average.
    if self.optimizer and isinstance(self.optimizer,
                                     optimization.ExponentialMovingAverage):
      self.optimizer.swap_weights()

  def next_eval_inputs(self, iterator):
    """Fetches the next inputs for the model during eval.

    This method consumes the input iterator and returns the next inputs for the
    model and an additional logs dict. The output dict remains in the host (not
    sent to GPUs/TPUs) and is merged with the model outputs which will be
    processed later in `aggregate_logs`. This is useful for sending extra logs
    downstream that are not compatible with the accelerators.

    This function runs in eager mode.

    Args:
      iterator: Dataset iterator to generate the next inputs from.

    Returns:
      The inputs to the model, and an additional logs dictionary. The logs
      are not passed to the model, instead they are merged with model output
      logs.
    """
    passthrough_logs = dict()
    return next(iterator), passthrough_logs

  def eval_step(self, iterator):
    """See base class."""

    def step_fn(inputs):
      logs = self.task.validation_step(
          inputs, model=self.model, metrics=self.validation_metrics)
      if self.task.loss in logs:
        self._validation_loss.update_state(logs[self.task.loss])
      return logs

    inputs, passthrough_logs = self.next_eval_inputs(iterator)
    distributed_outputs = self.strategy.run(step_fn, args=(inputs,))
    logs = tf.nest.map_structure(
        self.strategy.experimental_local_results, distributed_outputs)

    # Model logs win on key collisions (the dict union below keeps the
    # right-hand side), so warn when the passthrough logs are shadowed.
    # Fixed typos in the original message ("pasthrough", "takes precedence").
    if set(logs.keys()) & set(passthrough_logs.keys()):
      logging.warning(
          (
              "Conflict between the passthrough log keys and the returned model"
              " log keys. Found %r keys in the passthrough logs and %r keys in"
              " the model logs. Model log keys take precedence."
          ),
          logs.keys(),
          passthrough_logs.keys(),
      )

    return passthrough_logs | logs

  def eval_end(self, aggregated_logs=None):
    """Processes evaluation results."""
    self.join()
    logs = {}
    for metric in self.validation_metrics:
      logs[metric.name] = metric.result()
    if self.validation_loss.count.numpy() != 0:
      logs[self.validation_loss.name] = self.validation_loss.result()
    else:
      # `self.validation_loss` metric was not updated, because the validation
      # loss was not returned from the task's `validation_step` method.
      logging.info("The task did not report validation loss.")
    if aggregated_logs:
      metrics = self.task.reduce_aggregated_logs(
          aggregated_logs, global_step=self.global_step)
      logs.update(metrics)

    if self._checkpoint_exporter:
      self._checkpoint_exporter.maybe_export_checkpoint(
          self.checkpoint, logs, self.global_step.numpy())
      metric_name = self.config.trainer.best_checkpoint_eval_metric
      logs["best_" +
           metric_name] = self._checkpoint_exporter.best_ckpt_logs[metric_name]

    # Swaps back weights after testing when EMA is used.
    # This happens after best checkpoint export so that average weights used for
    # eval are exported instead of regular weights.
    if self.optimizer and isinstance(self.optimizer,
                                     optimization.ExponentialMovingAverage):
      self.optimizer.swap_weights()
    return logs

  def eval_reduce(self, state=None, step_outputs=None):
    """Delegates per-step log aggregation to the task; see base class."""
    return self.task.aggregate_logs(state, step_outputs)
models-2.13.1/official/core/base_trainer_test.py
0 → 100644
View file @
472e2f80
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tensorflow_models.core.trainers.trainer."""
# pylint: disable=g-direct-tensorflow-import
import
gc
import
multiprocessing
import
os
import
sys
from
absl.testing
import
parameterized
import
orbit
import
portpicker
import
tensorflow
as
tf
from
tensorflow.python.distribute
import
combinations
from
tensorflow.python.distribute
import
strategy_combinations
from
official.core
import
base_trainer
as
trainer_lib
from
official.core
import
config_definitions
as
cfg
from
official.core
import
train_lib
from
official.utils.testing
import
mock_task
# Detect which hardware variant of this test binary is running by inspecting
# the executable name (e.g. a target named `..._test_tpu`). Used below to
# skip tests that are incompatible with TPU/GPU execution.
TPU_TEST = 'test_tpu' in sys.argv[0]
GPU_TEST = 'test_gpu' in sys.argv[0]
def all_strategy_combinations():
  """Returns the distribution-strategy matrix used by parameterized tests."""
  strategies = [
      strategy_combinations.default_strategy,
      strategy_combinations.cloud_tpu_strategy,
      strategy_combinations.one_device_strategy_gpu,
  ]
  return combinations.combine(distribution=strategies)
def create_in_process_cluster(num_workers, num_ps):
  """Creates and starts local servers and returns the cluster_resolver."""

  def _local_addresses(count):
    # One freshly-picked free localhost port per requested task.
    return [
        'localhost:%s' % portpicker.pick_unused_port() for _ in range(count)
    ]

  cluster_dict = {'worker': _local_addresses(num_workers)}
  if num_ps > 0:
    cluster_dict['ps'] = _local_addresses(num_ps)
  cluster_spec = tf.train.ClusterSpec(cluster_dict)

  # Workers need some inter_ops threads to work properly.
  worker_config = tf.compat.v1.ConfigProto()
  if multiprocessing.cpu_count() < num_workers + 1:
    worker_config.inter_op_parallelism_threads = num_workers + 1

  for worker_index in range(num_workers):
    tf.distribute.Server(
        cluster_spec,
        job_name='worker',
        task_index=worker_index,
        config=worker_config,
        protocol='grpc')

  for ps_index in range(num_ps):
    tf.distribute.Server(
        cluster_spec, job_name='ps', task_index=ps_index, protocol='grpc')

  return tf.distribute.cluster_resolver.SimpleClusterResolver(
      cluster_spec, rpc_layer='grpc')
def dataset_fn(input_context=None):
  """Builds an infinite dataset of (1, 1) float32 zeros for mock training."""
  del input_context  # Sharding is irrelevant for this constant dataset.

  def _zeros(_):
    return tf.zeros((1, 1), dtype=tf.float32)

  return (
      tf.data.Dataset.range(1)
      .repeat()
      .map(_zeros, num_parallel_calls=tf.data.experimental.AUTOTUNE))
class MockAsyncTrainer(trainer_lib._AsyncTrainer):
  """Mock AsyncTrainer to test the _AsyncTrainer class."""

  def __init__(self):
    # The strategy must be captured before init_async() so the coordinator is
    # created against the active (parameter-server) strategy.
    self._strategy = tf.distribute.get_strategy()
    self.init_async()

    # Counters incremented once per train/eval step; ONLY_FIRST_REPLICA keeps
    # a single authoritative value under distribution.
    self.global_step = tf.Variable(
        0,
        dtype=tf.int64,
        name='global_step',
        trainable=False,
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
    self.eval_global_step = tf.Variable(
        0,
        dtype=tf.int64,
        name='eval_global_step',
        trainable=False,
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)

    train_dataset = self.distribute_dataset(dataset_fn)
    orbit.StandardTrainer.__init__(
        self, train_dataset, options=orbit.StandardTrainerOptions())

    validation_dataset = self.distribute_dataset(dataset_fn)
    orbit.StandardEvaluator.__init__(
        self,
        validation_dataset,
        options=orbit.StandardEvaluatorOptions(use_tf_while_loop=True))

  def train_loop_begin(self):
    # Reset the counter so each train() call counts steps from zero.
    self.global_step.assign(0)

  def train_step(self, iterator):

    def replica_step(_):
      self.global_step.assign_add(1)

    self._strategy.run(replica_step, args=(next(iterator),))

  def train_loop_end(self):
    # join() blocks until all scheduled async steps have completed, so the
    # returned count reflects every step of the loop.
    self.join()
    return self.global_step.numpy()

  def eval_begin(self):
    # Reset the counter so each evaluate() call counts steps from zero.
    self.eval_global_step.assign(0)

  def eval_step(self, iterator):

    def replica_step(_):
      self.eval_global_step.assign_add(1)

    self._strategy.run(replica_step, args=(next(iterator),))

  def eval_end(self):
    # As in train_loop_end: wait for outstanding async work before reading.
    self.join()
    return self.eval_global_step.numpy()
class TrainerTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for trainer_lib.Trainer, including async (parameter-server) mode."""

  def setUp(self):
    super().setUp()
    # Minimal experiment config: SGD with a constant learning rate.
    self._config = cfg.ExperimentConfig(
        trainer=cfg.TrainerConfig(
            optimizer_config=cfg.OptimizationConfig({
                'optimizer': {
                    'type': 'sgd'
                },
                'learning_rate': {
                    'type': 'constant'
                }
            })))

  def tearDown(self):
    gc.collect()
    # This will only contain uncollectable garbage, i.e. reference cycles
    # involving objects with __del__ defined.
    self.assertEmpty(gc.garbage)
    super().tearDown()

  def create_test_trainer(self, config, model_dir=None, task=None):
    """Builds a Trainer (with optional best-checkpoint exporter) under test."""
    task = task or mock_task.MockTask(config.task, logging_dir=model_dir)
    ckpt_exporter = train_lib.maybe_create_best_ckpt_exporter(config, model_dir)
    trainer = trainer_lib.Trainer(
        config,
        task,
        model=task.build_model(),
        optimizer=task.create_optimizer(config.trainer.optimizer_config,
                                        config.runtime),
        checkpoint_exporter=ckpt_exporter)
    return trainer

  def _create_parameter_server_strategy(self):
    """Builds an in-process ParameterServerStrategy for the async tests."""
    cluster_resolver = create_in_process_cluster(3, 2)
    return tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)

  @combinations.generate(all_strategy_combinations())
  def test_trainer_train(self, distribution):
    with distribution.scope():
      trainer = self.create_test_trainer(self._config)
      logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
      self.assertIn('training_loss', logs)
      self.assertIn('learning_rate', logs)

  @combinations.generate(all_strategy_combinations())
  def test_trainer_passing_datasets(self, distribution):
    """Datasets supplied directly to the Trainer override the task's configs."""
    with distribution.scope():
      task = mock_task.MockTask(self._config)
      train_dataset = orbit.utils.make_distributed_dataset(
          distribution, task.build_inputs, self._config.task.train_data)
      validation_dataset = orbit.utils.make_distributed_dataset(
          distribution, task.build_inputs, self._config.task.validation_data)
      # Clear the data configs to prove the explicit datasets are used.
      self._config.task.train_data = None
      self._config.task.validation_data = None
      trainer = trainer_lib.Trainer(
          self._config,
          task,
          model=task.build_model(),
          optimizer=task.create_optimizer(self._config.trainer.optimizer_config,
                                          self._config.runtime),
          train_dataset=train_dataset,
          validation_dataset=validation_dataset)
    logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
    self.assertIn('training_loss', logs)
    self.assertIn('learning_rate', logs)
    logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
    self.assertIn('validation_loss', logs)

  def test_base_async_trainer(self):
    if TPU_TEST or GPU_TEST:
      self.skipTest('Async training is not available on GPU/TPU.')
    distribution = self._create_parameter_server_strategy()
    with distribution.scope():
      trainer = MockAsyncTrainer()
      trainer.init_async()
      self.assertIsInstance(
          trainer._coordinator,
          tf.distribute.experimental.coordinator.ClusterCoordinator)
      # MockAsyncTrainer counts steps, so train/evaluate return the step count.
      self.assertEqual(trainer.train(tf.constant(10)), 10)
      self.assertEqual(trainer.evaluate(tf.constant(11)), 11)

  def test_async_trainer_train(self):
    if TPU_TEST or GPU_TEST:
      self.skipTest('Async training is not available on GPU/TPU.')
    distribution = self._create_parameter_server_strategy()
    with distribution.scope():
      config = cfg.ExperimentConfig(**self._config.as_dict())
      config.trainer.eval_tf_while_loop = True
      trainer = self.create_test_trainer(config)
      logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
      self.assertIn('training_loss', logs)
      self.assertIn('learning_rate', logs)

  def test_async_trainer_validate(self):
    if TPU_TEST or GPU_TEST:
      self.skipTest('Async training is not available on GPU/TPU.')
    distribution = self._create_parameter_server_strategy()
    with distribution.scope():
      config = cfg.ExperimentConfig(**self._config.as_dict())
      config.trainer.eval_tf_while_loop = True
      trainer = self.create_test_trainer(config)
      logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
      self.assertIn('acc', logs)
      self.assertIn('validation_loss', logs)

  @combinations.generate(all_strategy_combinations())
  def test_trainer_validate(self, distribution):
    with distribution.scope():
      trainer = self.create_test_trainer(self._config)
      logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
      # MockTask's counter metric accumulates once per step per replica.
      self.assertEqual(logs['counter'], 5. * distribution.num_replicas_in_sync)
      self.assertIn('validation_loss', logs)

  @combinations.generate(all_strategy_combinations())
  def test_trainer_validate_without_loss(self, distribution):

    class MockTaskWithoutValidationLoss(mock_task.MockTask):

      def validation_step(self, inputs, model, metrics=None):
        # Disable validation loss.
        logs = super().validation_step(inputs, model)
        del logs[self.loss]
        return logs

    with distribution.scope():
      task = MockTaskWithoutValidationLoss()
      trainer = self.create_test_trainer(self._config, task=task)
      logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
      self.assertEqual(logs['counter'], 5. * distribution.num_replicas_in_sync)
      self.assertNotIn('validation_loss', logs)

  @combinations.generate(
      combinations.combine(
          mixed_precision_dtype=['float32', 'bfloat16', 'float16'],
          loss_scale=[None, 'dynamic', 128, 256],
      ))
  def test_configure_optimizer(self, mixed_precision_dtype, loss_scale):
    """Loss scaling wraps the optimizer only for float16 mixed precision."""
    config = cfg.ExperimentConfig(
        runtime=cfg.RuntimeConfig(
            mixed_precision_dtype=mixed_precision_dtype, loss_scale=loss_scale),
        trainer=cfg.TrainerConfig(
            optimizer_config=cfg.OptimizationConfig({
                'optimizer': {
                    'type': 'sgd'
                },
                'learning_rate': {
                    'type': 'constant'
                },
            })))
    trainer = self.create_test_trainer(config)
    if mixed_precision_dtype == 'float16':
      self.assertIsInstance(trainer.optimizer,
                            tf.keras.mixed_precision.LossScaleOptimizer)
      if loss_scale in (None, 'dynamic'):
        self.assertTrue(trainer.optimizer.dynamic)
      else:
        self.assertFalse(trainer.optimizer.dynamic)
        self.assertEqual(trainer.optimizer.initial_scale, loss_scale)
    else:
      self.assertIsInstance(
          trainer.optimizer,
          (tf.keras.optimizers.SGD, tf.keras.optimizers.legacy.SGD))
    metrics = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
    self.assertIn('training_loss', metrics)

  def test_export_best_ckpt(self):
    config = cfg.ExperimentConfig(
        trainer=cfg.TrainerConfig(
            best_checkpoint_export_subdir='best_ckpt',
            best_checkpoint_eval_metric='acc',
            optimizer_config=cfg.OptimizationConfig({
                'optimizer': {
                    'type': 'sgd'
                },
                'learning_rate': {
                    'type': 'constant'
                }
            })))
    model_dir = self.get_temp_dir()
    trainer = self.create_test_trainer(config, model_dir=model_dir)
    trainer.train(tf.convert_to_tensor(1, dtype=tf.int32))
    trainer.evaluate(tf.convert_to_tensor(1, dtype=tf.int32))
    # The exporter writes its bookkeeping file under the configured subdir.
    self.assertTrue(
        tf.io.gfile.exists(os.path.join(model_dir, 'best_ckpt', 'info.json')))

  def test_model_with_compiled_loss(self):
    task = mock_task.MockTask()
    model = task.build_model()
    model.compile(loss=tf.keras.losses.CategoricalCrossentropy())
    trainer = trainer_lib.Trainer(
        self._config,
        task,
        model=model,
        optimizer=task.create_optimizer(self._config.trainer.optimizer_config))
    logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
    self.assertIn('training_loss', logs)
# Standard TensorFlow test entry point: discovers and runs the test cases
# defined in this module.
if __name__ == '__main__':
  tf.test.main()
models-2.13.1/official/core/config_definitions.py
0 → 100644
View file @
472e2f80
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common configuration settings."""
import
dataclasses
from
typing
import
Optional
,
Sequence
,
Union
from
official.modeling.hyperparams
import
base_config
from
official.modeling.optimization.configs
import
optimization_config
from
official.modeling.privacy
import
configs
as
dp_configs
# Re-exported alias so callers can use `cfg.OptimizationConfig` without
# importing the optimization configs module directly.
OptimizationConfig = optimization_config.OptimizationConfig
@dataclasses.dataclass
class DataConfig(base_config.Config):
  """The base configuration for building datasets.

  Attributes:
    input_path: The path to the input. It can be either (1) a str indicating a
      file path/pattern, or (2) a str indicating multiple file paths/patterns
      separated by comma (e.g "a, b, c" or no spaces "a,b,c"), or (3) a list of
      str, each of which is a file path/pattern or multiple file paths/patterns
      separated by comma, or (4) a dictionary of the previous three approaches
      for more advanced data mixing using named access. It should not be
      specified when the following `tfds_name` is specified.
    tfds_name: The name of the tensorflow dataset (TFDS). It should not be
      specified when the above `input_path` is specified.
    tfds_split: A str indicating which split of the data to load from TFDS. It
      is required when above `tfds_name` is specified.
    global_batch_size: The global batch size across all replicas.
    is_training: Whether this data is used for training or not. This flag is
      useful for consumers of this object to determine whether the data should
      be repeated or shuffled.
    drop_remainder: Whether the last batch should be dropped in the case it has
      fewer than `global_batch_size` elements.
    shuffle_buffer_size: The buffer size used for shuffling training data.
    cache: Whether to cache dataset examples. If `True`, we will cache the
      dataset after applying the decode_fn and parse_fn. It can be used to avoid
      re-reading from disk, re-decoding and re-parsing the example on the second
      epoch, but it requires significant memory overhead.
    cycle_length: The number of files that will be processed concurrently when
      interleaving files.
    block_length: The number of consecutive elements to produce from each input
      element before cycling to another input element when interleaving files.
    deterministic: A boolean controlling whether determinism should be enforced.
    sharding: Whether sharding is used in the input pipeline.
    enable_tf_data_service: A boolean indicating whether to enable tf.data
      service for the input pipeline.
    tf_data_service_address: The URI of a tf.data service to offload
      preprocessing onto during training. The URI should be in the format
      "protocol://address", e.g. "grpc://tf-data-service:5050". It can be
      overridden by `FLAGS.tf_data_service` flag in the binary.
    tf_data_service_job_name: The name of the tf.data service job. This argument
      makes it possible for multiple datasets to share the same job. The default
      behavior is that the dataset creates anonymous, exclusively owned jobs.
    tfds_data_dir: A str specifying the directory to read/write TFDS data.
    tfds_as_supervised: A bool. When loading dataset from TFDS, if True, the
      returned tf.data.Dataset will have a 2-tuple structure (input, label)
      according to builder.info.supervised_keys; if False, the default, the
      returned tf.data.Dataset will have a dictionary with all the features.
    tfds_skip_decoding_feature: A str to indicate which features are skipped for
      decoding when loading dataset from TFDS. Use comma to separate multiple
      features. The main use case is to skip the image/video decoding for better
      performance.
    enable_shared_tf_data_service_between_parallel_trainers: A bool. When set to
      true, only a single tf.data service will be started, and it will be shared
      between all the trainer run simultaneously, e.g. using vizier to tune
      hyperparameters. This will save CPU and RAM resources compared to running
      separate tf.data service for each trainer. Notice that if batch size is
      different for different trainers, the field
      apply_tf_data_service_before_batching also needs to be true so that only a
      single tf.data service instance will be created. In this case, tf.data
      service will be applied before batching operation. So make sure to not
      apply any processing steps after batching (e.g. in postprocess_fn) since
      they wouldn't be paralleled by tf.data service and may slow down your
      tf.data pipeline. When using shared tf.data service, the tf.data dataset
      must be infinite, and slow trainer may skip certain training examples.
      More details about shared tf.data service can be found at:
      https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#sharing_tfdata_service_with_concurrent_trainers.
    apply_tf_data_service_before_batching: A bool. If set to True, tf.data
      service will be applied before batching operation. This is useful to make
      sure only a single tf.data service instance is created when
      enable_shared_tf_data_service_between_parallel_trainers is true and batch
      size is changing between parallel trainers.
    trainer_id: A string. The id of the trainer if there are multiple parallel
      trainer running at the same time, e.g. in vizier tuning case. It will be
      automatically set if this field is needed. Users does not need to set it
      when creating experiment configs.
    seed: An optional seed to use for deterministic shuffling/preprocessing.
    prefetch_buffer_size: An int specifying the buffer size of prefetch
      datasets. If None, the buffer size is autotuned. Specifying this is useful
      in case autotuning uses up too much memory by making the buffer size too
      high.
    autotune_algorithm: If specified, use this algorithm for AUTOTUNE. See:
      https://www.tensorflow.org/api_docs/python/tf/data/experimental/AutotuneAlgorithm
  """
  input_path: Union[Sequence[str], str, base_config.Config] = ""
  tfds_name: Union[str, base_config.Config] = ""
  tfds_split: str = ""
  global_batch_size: int = 0
  is_training: Optional[bool] = None
  drop_remainder: bool = True
  shuffle_buffer_size: int = 100
  cache: bool = False
  cycle_length: Optional[int] = None
  block_length: int = 1
  deterministic: Optional[bool] = None
  sharding: bool = True
  enable_tf_data_service: bool = False
  tf_data_service_address: Optional[str] = None
  tf_data_service_job_name: Optional[str] = None
  tfds_data_dir: str = ""
  tfds_as_supervised: bool = False
  tfds_skip_decoding_feature: str = ""
  enable_shared_tf_data_service_between_parallel_trainers: bool = False
  apply_tf_data_service_before_batching: bool = False
  trainer_id: Optional[str] = None
  seed: Optional[int] = None
  prefetch_buffer_size: Optional[int] = None
  autotune_algorithm: Optional[str] = None
@dataclasses.dataclass
class RuntimeConfig(base_config.Config):
  """High-level configurations for Runtime.

  These include parameters that are not directly related to the experiment,
  e.g. directories, accelerator type, etc.

  Attributes:
    distribution_strategy: e.g. 'mirrored', 'tpu', etc.
    enable_xla: Whether or not to enable XLA.
    per_gpu_thread_count: thread count per GPU.
    gpu_thread_mode: Whether and how the GPU device uses its own threadpool.
    dataset_num_private_threads: Number of threads for a private threadpool
      created for all datasets computation.
    tpu: The address of the TPU to use, if any.
    num_gpus: The number of GPUs to use, if any.
    worker_hosts: comma-separated list of worker ip:port pairs for running
      multi-worker models with DistributionStrategy.
    task_index: If multi-worker training, the task index of this worker.
    all_reduce_alg: Defines the algorithm for performing all-reduce.
    num_packs: Sets `num_packs` in the cross device ops used in
      MirroredStrategy. For details, see tf.distribute.NcclAllReduce.
    mixed_precision_dtype: dtype of mixed precision policy. It can be 'float32',
      'float16', or 'bfloat16'.
    loss_scale: The type of loss scale, or 'float' value. This is used when
      setting the mixed precision policy.
    run_eagerly: Whether or not to run the experiment eagerly.
    batchnorm_spatial_persistent: Whether or not to enable the spatial
      persistent mode for CuDNN batch norm kernel for improved GPU performance.
  """
  distribution_strategy: str = "mirrored"
  enable_xla: bool = False
  gpu_thread_mode: Optional[str] = None
  dataset_num_private_threads: Optional[int] = None
  per_gpu_thread_count: int = 0
  tpu: Optional[str] = None
  num_gpus: int = 0
  worker_hosts: Optional[str] = None
  task_index: int = -1
  all_reduce_alg: Optional[str] = None
  num_packs: int = 1
  mixed_precision_dtype: Optional[str] = None
  loss_scale: Optional[Union[str, float]] = None
  run_eagerly: bool = False
  batchnorm_spatial_persistent: bool = False

  # XLA runtime params.
  # XLA params are only applied to the train_step.
  # These augments can improve training speed. They can also improve eval, but
  # may reduce usability and users would need to make changes to code.

  # Whether to enable XLA dynamic padder
  # infrastructure to handle dynamic shapes inputs inside XLA. True by
  # default. Disabling this may cause correctness issues with dynamic shapes
  # inputs, as XLA will just assume the inputs are with padded shapes. However
  # users can optionally set it to False to improve device time if masking is
  # already handled in the user side.
  # If None, will respect XLA default.
  tpu_enable_xla_dynamic_padder: Optional[bool] = None

  # Global model parallelism configurations.
  num_cores_per_replica: int = 1
  default_shard_dim: int = -1
  use_tpu_mp_strategy: bool = False

  def model_parallelism(self):
    """Returns the model-parallelism fields as a kwargs-style dict."""
    return dict(
        num_cores_per_replica=self.num_cores_per_replica,
        default_shard_dim=self.default_shard_dim)
@dataclasses.dataclass
class TrainerConfig(base_config.Config):
  """Configuration for trainer.

  Attributes:
    optimizer_config: optimizer config, it includes optimizer, learning rate,
      and warmup schedule configs.
    train_tf_while_loop: whether or not to use tf while loop.
    train_tf_function: whether or not to use tf_function for training loop.
    eval_tf_function: whether or not to use tf_function for eval.
    eval_tf_while_loop: whether or not to use tf while loop for eval.
    allow_tpu_summary: Whether to allow summary happen inside the XLA program
      runs on TPU through automatic outside compilation.
    steps_per_loop: number of steps per loop to report training metrics. This
      can also be used to reduce host worker communication in a TPU setup.
    summary_interval: number of steps between each summary.
    checkpoint_interval: number of steps between checkpoints.
    max_to_keep: max checkpoints to keep.
    continuous_eval_timeout: maximum number of seconds to wait between
      checkpoints, if set to None, continuous eval will wait indefinitely. This
      is only used continuous_train_and_eval and continuous_eval modes. Default
      value is 1 hrs.
    train_steps: number of train steps.
    validation_steps: number of eval steps. If -1, the entire eval dataset is
      used.
    validation_interval: number of training steps to run between evaluations.
    best_checkpoint_export_subdir: if set, the trainer will keep track of the
      best evaluation metric, and export the corresponding best checkpoint under
      `model_dir/best_checkpoint_export_subdir`. Note that this only works if
      mode contains eval (such as `train_and_eval`, `continuous_eval`, and
      `continuous_train_and_eval`).
    best_checkpoint_eval_metric: for exporting the best checkpoint, which
      evaluation metric the trainer should monitor. This can be any evaluation
      metric appears on tensorboard.
    best_checkpoint_metric_comp: for exporting the best checkpoint, how the
      trainer should compare the evaluation metrics. This can be either `higher`
      (higher the better) or `lower` (lower the better).
    validation_summary_subdir: A 'str', sub directory for saving eval summary.
    preemption_on_demand_checkpoint: whether or not to save on-demand
      checkpoints after a preemption.
  """
  optimizer_config: OptimizationConfig = OptimizationConfig()
  # Orbit settings.
  train_tf_while_loop: bool = True
  train_tf_function: bool = True
  eval_tf_function: bool = True
  eval_tf_while_loop: bool = False
  allow_tpu_summary: bool = False
  # Trainer intervals.
  steps_per_loop: int = 1000
  summary_interval: int = 1000
  checkpoint_interval: int = 1000
  # Checkpoint manager.
  max_to_keep: int = 5
  continuous_eval_timeout: int = 60 * 60
  # Train/Eval routines.
  train_steps: int = 0
  # Sets validation steps to be -1 to evaluate the entire dataset.
  validation_steps: int = -1
  validation_interval: int = 1000
  # Best checkpoint export.
  best_checkpoint_export_subdir: str = ""
  best_checkpoint_eval_metric: str = ""
  best_checkpoint_metric_comp: str = "higher"
  # Blowup recovery.
  loss_upper_bound: float = 1e6
  recovery_begin_steps: int = 0  # Enforcing the loss bound after these steps.
  # When max trials < 0, no recovery module; max trials = 0, we will check
  # the condition and fail the job if the condition happens; max trials > 0,
  # we will retore the model states.
  recovery_max_trials: int = 0
  validation_summary_subdir: str = "validation"
  # Preemption on-demand checkpoint.
  preemption_on_demand_checkpoint: bool = True
@dataclasses.dataclass
class TaskConfig(base_config.Config):
  """Config passed to task."""
  # Path to an initial checkpoint to warm-start from; empty means none.
  init_checkpoint: str = ""
  # Model config; concrete tasks substitute their own model config type.
  model: Optional[base_config.Config] = None
  train_data: DataConfig = DataConfig()
  validation_data: DataConfig = DataConfig()
  # Optional task name override.
  name: Optional[str] = None
  # Configs for differential privacy
  # These configs are only effective if you use create_optimizer in
  # tensorflow_models/official/core/base_task.py
  # DEPRECATED b/264611883
  differential_privacy_config: Optional[
      dp_configs.DifferentialPrivacyConfig] = None
  # Whether to show image summary. Useful to visualize model predictions. Only
  # work for vision tasks.
  allow_image_summary: bool = False
@dataclasses.dataclass
class ExperimentConfig(base_config.Config):
  """Top-level configuration."""
  # The three sub-configs that fully describe an experiment: what to run
  # (task), how to optimize/loop (trainer), and where/on what hardware
  # (runtime).
  task: TaskConfig = TaskConfig()
  trainer: TrainerConfig = TrainerConfig()
  runtime: RuntimeConfig = RuntimeConfig()
models-2.13.1/official/core/exp_factory.py
0 → 100644
View file @
472e2f80
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Experiment factory methods."""
from
official.core
import
config_definitions
as
cfg
from
official.core
import
registry
# Module-level registry mapping experiment names to ExperimentConfig factory
# callables; populated via `register_config_factory` below.
_REGISTERED_CONFIGS = {}
def register_config_factory(name):
  """Register ExperimentConfig factory method."""
  decorator = registry.register(_REGISTERED_CONFIGS, name)
  return decorator
def get_exp_config(exp_name: str) -> cfg.ExperimentConfig:
  """Looks up the `ExperimentConfig` according to the `exp_name`."""
  factory_fn = registry.lookup(_REGISTERED_CONFIGS, exp_name)
  return factory_fn()
models-2.13.1/official/core/export_base.py
0 → 100644
View file @
472e2f80
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base class for model export."""
import
abc
import
functools
import
time
from
typing
import
Any
,
Callable
,
Dict
,
Mapping
,
List
,
Optional
,
Text
,
Union
from
absl
import
logging
import
tensorflow
as
tf
# Retry bound for creating an export directory; presumably consumed by the
# `export` helper when writing timestamped SavedModel dirs — confirm at the
# use site (outside this chunk).
MAX_DIRECTORY_CREATION_ATTEMPTS = 10
class
ExportModule
(
tf
.
Module
,
metaclass
=
abc
.
ABCMeta
):
"""Base Export Module."""
def
__init__
(
self
,
params
,
model
:
Union
[
tf
.
Module
,
tf
.
keras
.
Model
],
inference_step
:
Optional
[
Callable
[...,
Any
]]
=
None
,
*
,
preprocessor
:
Optional
[
Callable
[...,
Any
]]
=
None
,
postprocessor
:
Optional
[
Callable
[...,
Any
]]
=
None
):
"""Instantiates an ExportModel.
Examples:
`inference_step` must be a function that has `model` as an kwarg or the
second positional argument.
```
def _inference_step(inputs, model=None):
return model(inputs, training=False)
module = ExportModule(params, model, inference_step=_inference_step)
```
`preprocessor` and `postprocessor` could be either functions or `tf.Module`.
The usages of preprocessor and postprocessor are managed by the
implementation of `serve()` method.
Args:
params: A dataclass for parameters to the module.
model: A model instance which contains weights and forward computation.
inference_step: An optional callable to forward-pass the model. If not
specified, it creates a parital function with `model` as an required
kwarg.
preprocessor: An optional callable to preprocess the inputs.
postprocessor: An optional callable to postprocess the model outputs.
"""
super
().
__init__
(
name
=
None
)
self
.
model
=
model
self
.
params
=
params
if
inference_step
is
not
None
:
self
.
inference_step
=
functools
.
partial
(
inference_step
,
model
=
self
.
model
)
else
:
if
issubclass
(
type
(
model
),
tf
.
keras
.
Model
):
# Default to self.model.call instead of self.model.__call__ to avoid
# keras tracing logic designed for training.
# Since most of Model Garden's call doesn't not have training kwargs
# or the default is False, we don't pass anything here.
# Please pass custom inference step if your model has training=True as
# default.
self
.
inference_step
=
self
.
model
.
call
else
:
self
.
inference_step
=
functools
.
partial
(
self
.
model
.
__call__
,
training
=
False
)
self
.
preprocessor
=
preprocessor
self
.
postprocessor
=
postprocessor
@
abc
.
abstractmethod
def
serve
(
self
)
->
Mapping
[
Text
,
tf
.
Tensor
]:
"""The bare inference function which should run on all devices.
Expecting tensors are passed in through keyword arguments. Returns a
dictionary of tensors, when the keys will be used inside the SignatureDef.
"""
  @abc.abstractmethod
  def get_inference_signatures(
      self, function_keys: Dict[Text, Text]) -> Mapping[Text, Any]:
    """Get defined function signatures.

    Args:
      function_keys: Mapping from a predefined function name to the signature
        key it should be exported under (see `export()` below, which builds
        this mapping and passes the result to `tf.saved_model.save`).

    Returns:
      A mapping from signature key to a concrete function usable as a
      SavedModel serving signature.
    """
def export(export_module: ExportModule,
           function_keys: Union[List[Text], Dict[Text, Text]],
           export_savedmodel_dir: Text,
           checkpoint_path: Optional[Text] = None,
           timestamped: bool = True,
           save_options: Optional[tf.saved_model.SaveOptions] = None,
           checkpoint: Optional[tf.train.Checkpoint] = None) -> Text:
  """Exports to SavedModel format.

  Args:
    export_module: a ExportModule with the keras Model and serving
      tf.functions.
    function_keys: a list of string keys to retrieve pre-defined serving
      signatures. The signaute keys will be set with defaults. If a dictionary
      is provided, the values will be used as signature keys.
    export_savedmodel_dir: Output saved model directory.
    checkpoint_path: Object-based checkpoint path or directory.
    timestamped: Whether to export the savedmodel to a timestamped directory.
    save_options: `SaveOptions` for `tf.saved_model.save`.
    checkpoint: An optional tf.train.Checkpoint. If provided, the export
      module will use it to read the weights.

  Returns:
    The savedmodel directory path.
  """
  # A directory path is resolved to its newest checkpoint file.
  ckpt_dir_or_file = checkpoint_path
  if ckpt_dir_or_file is not None and tf.io.gfile.isdir(ckpt_dir_or_file):
    ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)

  if ckpt_dir_or_file:
    if checkpoint is None:
      checkpoint = tf.train.Checkpoint(model=export_module.model)
    # Partial restores are tolerated, but every object in the checkpoint must
    # find a match in the model.
    restore_status = checkpoint.read(ckpt_dir_or_file)
    restore_status.assert_existing_objects_matched().expect_partial()

  if isinstance(function_keys, list):
    # Guard clause: a list form is only allowed for the single default
    # serving signature.
    if len(function_keys) != 1:
      raise ValueError(
          'If the function_keys is a list, it must contain a single element. %s'
          % function_keys)
    function_keys = {
        function_keys[0]: tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    }

  signatures = export_module.get_inference_signatures(function_keys)
  if timestamped:
    # get_timestamped_export_dir returns bytes; decode for tf.saved_model.save.
    export_dir = get_timestamped_export_dir(export_savedmodel_dir).decode(
        'utf-8')
  else:
    export_dir = export_savedmodel_dir
  tf.saved_model.save(
      export_module, export_dir, signatures=signatures, options=save_options)
  return export_dir
def get_timestamped_export_dir(export_dir_base):
  """Builds a path to a new subdirectory within the base directory.

  Args:
    export_dir_base: A string containing a directory to write the exported
      graph and checkpoints.

  Returns:
    The full path of the new subdirectory (which is not actually created yet),
    as bytes.

  Raises:
    RuntimeError: if repeated attempts fail to obtain a unique timestamped
      directory name.
  """
  # Second-resolution timestamps can collide with an earlier export, so retry
  # (after a short sleep) up to the configured limit.
  for attempt in range(1, MAX_DIRECTORY_CREATION_ATTEMPTS + 1):
    timestamp = int(time.time())
    candidate_dir = tf.io.gfile.join(
        tf.compat.as_bytes(export_dir_base),
        tf.compat.as_bytes(str(timestamp)))
    if not tf.io.gfile.exists(candidate_dir):
      # Collisions are still possible (though extremely unlikely): this
      # directory is not actually created yet, but it will be almost
      # instantly on return from this function.
      return candidate_dir
    time.sleep(1)
    logging.warning('Directory %s already exists; retrying (attempt %s/%s)',
                    str(candidate_dir), attempt,
                    MAX_DIRECTORY_CREATION_ATTEMPTS)
  raise RuntimeError('Failed to obtain a unique export directory name after '
                     f'{MAX_DIRECTORY_CREATION_ATTEMPTS} attempts.')
models-2.13.1/official/core/export_base_test.py
0 → 100644
View file @
472e2f80
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.core.export_base."""
import
os
from
typing
import
Any
,
Dict
,
Mapping
,
Text
import
tensorflow
as
tf
from
official.core
import
export_base
class TestModule(export_base.ExportModule):
  """Minimal concrete ExportModule used by the tests below."""

  @tf.function
  def serve(self, inputs: tf.Tensor) -> Mapping[Text, tf.Tensor]:
    """Runs preprocessor -> inference_step -> postprocessor on `inputs`."""
    if self.preprocessor is None:
      x = inputs
    else:
      x = self.preprocessor(inputs=inputs)
    x = self.inference_step(x)
    if self.postprocessor:
      x = self.postprocessor(x)
    return {'outputs': x}

  def get_inference_signatures(
      self, function_keys: Dict[Text, Text]) -> Mapping[Text, Any]:
    """Exposes `serve` under the fixed signature key 'foo'."""
    spec = tf.TensorSpec(shape=[None, None], dtype=tf.float32)
    return {'foo': self.serve.get_concrete_function(spec)}
class ExportBaseTest(tf.test.TestCase):
  """Tests for `export_base.ExportModule` and `export_base.export`."""

  def test_export_module(self):
    # End-to-end: save a checkpoint, export a timestamped SavedModel, re-load
    # it and check the 'foo' signature reproduces the in-memory model output.
    tmp_dir = self.get_temp_dir()
    model = tf.keras.layers.Dense(2)
    inputs = tf.ones([2, 4], tf.float32)
    expected_output = model(inputs, training=False)
    module = TestModule(params=None, model=model)
    ckpt_path = tf.train.Checkpoint(model=model).save(
        os.path.join(tmp_dir, 'ckpt'))
    export_dir = export_base.export(
        module, ['foo'],
        export_savedmodel_dir=tmp_dir,
        checkpoint_path=ckpt_path,
        timestamped=True)
    # The SavedModel protobuf and both variable shard files must exist.
    self.assertTrue(os.path.exists(os.path.join(export_dir, 'saved_model.pb')))
    self.assertTrue(
        os.path.exists(
            os.path.join(export_dir, 'variables', 'variables.index')))
    self.assertTrue(
        os.path.exists(
            os.path.join(export_dir, 'variables',
                         'variables.data-00000-of-00001')))
    imported = tf.saved_model.load(export_dir)
    output = imported.signatures['foo'](inputs)
    self.assertAllClose(output['outputs'].numpy(), expected_output.numpy())

  def test_custom_inference_step(self):
    # A user-supplied inference_step (softmax on top of the model) must be
    # captured by the exported signature; timestamped=False exports straight
    # into tmp_dir.
    tmp_dir = self.get_temp_dir()
    model = tf.keras.layers.Dense(2)
    inputs = tf.ones([2, 4], tf.float32)

    def _inference_step(inputs, model):
      return tf.nn.softmax(model(inputs, training=False))

    module = TestModule(
        params=None, model=model, inference_step=_inference_step)
    expected_output = _inference_step(inputs, model)
    ckpt_path = tf.train.Checkpoint(model=model).save(
        os.path.join(tmp_dir, 'ckpt'))
    export_dir = export_base.export(
        module, ['foo'],
        export_savedmodel_dir=tmp_dir,
        checkpoint_path=ckpt_path,
        timestamped=False)
    imported = tf.saved_model.load(export_dir)
    output = imported.signatures['foo'](inputs)
    self.assertAllClose(output['outputs'].numpy(), expected_output.numpy())

  def test_processors(self):
    # preprocessor adds 0.1, inference_step adds 1.0, postprocessor adds 0.01;
    # serve() must apply them in exactly that order.
    model = tf.Module()
    inputs = tf.zeros((), tf.float32)

    def _inference_step(inputs, model):
      del model
      return inputs + 1.0

    def _preprocessor(inputs):
      print(inputs)
      return inputs + 0.1

    # Without a postprocessor: 0.0 + 0.1 + 1.0 == 1.1.
    module = TestModule(
        params=None,
        model=model,
        inference_step=_inference_step,
        preprocessor=_preprocessor)
    output = module.serve(inputs)
    self.assertAllClose(output['outputs'].numpy(), 1.1)

    # Postprocessors may be tf.Modules as well as plain callables.
    class _PostProcessor(tf.Module):

      def __call__(self, inputs):
        return inputs + 0.01

    module = TestModule(
        params=None,
        model=model,
        inference_step=_inference_step,
        preprocessor=_preprocessor,
        postprocessor=_PostProcessor())
    output = module.serve(inputs)
    self.assertAllClose(output['outputs'].numpy(), 1.11)

  def test_get_timestamped_export_dir(self):
    # The timestamped path lives under the base dir and is not created yet.
    export_dir = self.get_temp_dir()
    timed_dir = export_base.get_timestamped_export_dir(
        export_dir_base=export_dir)
    self.assertFalse(tf.io.gfile.exists(timed_dir))
    self.assertIn(export_dir, str(timed_dir))
# Standard TF test entry point.
if __name__ == '__main__':
  tf.test.main()
models-2.13.1/official/core/file_writers.py
0 → 100644
View file @
472e2f80
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""File writer functions for dataset preparation, infra validation, and unit tests."""
import
io
from
typing
import
Optional
,
Sequence
,
Union
import
tensorflow
as
tf
def write_small_dataset(examples: Sequence[Union[tf.train.Example,
                                                 tf.train.SequenceExample]],
                        output_path: str,
                        file_type: str = 'tfrecord') -> None:
  """Writes `examples` to a file at `output_path` with type `file_type`.

  CAVEAT: This function is not recommended for writing large datasets, since
  it will loop through `examples` and perform write operation sequentially.

  Args:
    examples: List of tf.train.Example or tf.train.SequenceExample.
    output_path: Output path for the dataset.
    file_type: A string indicating the file format, could be: 'tfrecord',
      'tfrecords', 'tfrecord_compressed', 'tfrecords_gzip', 'riegeli'. The
      string is case insensitive.

  Raises:
    ValueError: If `file_type` is not one of the recognized formats.
  """
  # Comparison is case-insensitive.
  file_type = file_type.lower()
  if file_type in ('tfrecord', 'tfrecords'):
    _write_tfrecord(examples, output_path)
  elif file_type in ('tfrecord_compressed', 'tfrecords_gzip'):
    _write_tfrecord(examples, output_path,
                    tf.io.TFRecordOptions(compression_type='GZIP'))
  elif file_type == 'riegeli':
    _write_riegeli(examples, output_path)
  else:
    raise ValueError(f'Unknown file_type: {file_type}')
def _write_tfrecord(examples: Sequence[Union[tf.train.Example,
                                             tf.train.SequenceExample]],
                    output_path: str,
                    options: Optional[tf.io.TFRecordOptions] = None) -> None:
  """Writes `examples` to a TFRecord file at `output_path`.

  Args:
    examples: A list of tf.train.Example.
    output_path: Output path for the dataset.
    options: Options used for manipulating TFRecord files.
  """
  # Each proto is serialized and appended in order; the context manager
  # flushes and closes the record file.
  with tf.io.TFRecordWriter(output_path, options) as record_writer:
    for record in examples:
      record_writer.write(record.SerializeToString())
def _write_riegeli(examples: Sequence[Union[tf.train.Example,
                                            tf.train.SequenceExample]],
                   output_path: str) -> None:
  """Writes `examples` to a Riegeli file at `output_path`.

  Args:
    examples: A list of tf.train.Example.
    output_path: Output path for the dataset.
  """
  with io.FileIO(output_path, 'wb') as raw_file:
    # Imported lazily so environments without riegeli can still use the
    # TFRecord code paths in this module.
    import riegeli  # pylint: disable=g-import-not-at-top
    with riegeli.RecordWriter(raw_file) as record_writer:
      record_writer.write_messages(examples)
models-2.13.1/official/core/file_writers_test.py
0 → 100644
View file @
472e2f80
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for file_writers."""
import
os
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
official.core
import
file_writers
from
official.core
import
tf_example_builder
class FileWritersTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for `file_writers.write_small_dataset`."""

  def setUp(self):
    super().setUp()
    # A minimal single-feature tf.train.Example shared by all tests.
    example_builder = tf_example_builder.TfExampleBuilder()
    example_builder.add_bytes_feature('foo', 'Hello World!')
    self._example = example_builder.example

  @parameterized.parameters('tfrecord', 'TFRecord', 'tfrecords',
                            'tfrecord_compressed', 'TFRecord_Compressed',
                            'tfrecords_gzip')
  def test_write_small_dataset_success(self, file_type):
    # Every supported (case-insensitive) spelling must produce an output file.
    temp_dir = self.create_tempdir()
    temp_dataset_file = os.path.join(temp_dir.full_path, 'train')
    file_writers.write_small_dataset([self._example], temp_dataset_file,
                                     file_type)
    self.assertTrue(os.path.exists(temp_dataset_file))

  def test_write_small_dataset_unrecognized_format(self):
    # Unknown file types are rejected with ValueError.
    file_type = 'bar'
    temp_dir = self.create_tempdir()
    temp_dataset_file = os.path.join(temp_dir.full_path, 'train')
    with self.assertRaises(ValueError):
      file_writers.write_small_dataset([self._example], temp_dataset_file,
                                       file_type)
# Standard TF test entry point.
if __name__ == '__main__':
  tf.test.main()
models-2.13.1/official/core/input_reader.py
0 → 100644
View file @
472e2f80
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A common dataset reader."""
import
dataclasses
import
random
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Sequence
,
Text
,
Union
from
absl
import
logging
import
tensorflow
as
tf
import
tensorflow_datasets
as
tfds
from
official.core
import
config_definitions
as
cfg
def
_get_random_integer
():
return
random
.
randint
(
0
,
(
1
<<
31
)
-
1
)
def _maybe_map_fn(dataset: tf.data.Dataset,
                  fn: Optional[Callable[..., Any]] = None) -> tf.data.Dataset:
  """Calls dataset.map if a valid function is passed in."""
  if fn is None:
    return dataset
  return dataset.map(fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
def match_files(input_path: Union[Sequence[str], str]) -> List[str]:
  """Matches files from an input_path.

  Args:
    input_path: (1) a str with a single file path/pattern, (2) a str with
      multiple comma-separated paths/patterns, or (3) a list/tuple of strs of
      form (1) or (2). Patterns containing '*' or '?' are glob-expanded.

  Returns:
    A flat list of matched file paths, in input order.

  Raises:
    ValueError: If `input_path` has an unsupported type, if a glob pattern
      matches no files, or if the whole input yields no files.
  """
  matched_files = []
  usage = ('`input_path` should be either (1) a str indicating a file '
           'path/pattern, or (2) a str indicating multiple file '
           'paths/patterns separated by comma (e.g "a, b, c" or no spaces '
           '"a,b,c", or (3) a list of str, each of which is a file '
           'path/pattern or multiple file paths/patterns separated by '
           'comma, but got: %s')
  if isinstance(input_path, str):
    input_path_list = [input_path]
  elif isinstance(input_path, (list, tuple)):
    if any(not isinstance(x, str) for x in input_path):
      raise ValueError(usage % input_path)
    input_path_list = input_path
  else:
    raise ValueError(usage % input_path)

  # NOTE: the loop variable is deliberately NOT named `input_path` — the
  # previous version shadowed the parameter, which corrupted the final error
  # message below.
  for path in input_path_list:
    for raw_pattern in path.strip().split(','):
      pattern = raw_pattern.strip()
      if not pattern:
        # Empty segments (e.g. from "a,,b" or trailing commas) are skipped.
        continue
      if '*' in pattern or '?' in pattern:
        globbed = tf.io.gfile.glob(pattern)
        if not globbed:
          raise ValueError('%s does not match any files.' % pattern)
        matched_files.extend(globbed)
      else:
        matched_files.append(pattern)

  if not matched_files:
    # Bug fix: report the full original input rather than the last element of
    # the (previously shadowed) loop variable.
    raise ValueError('%s does not match any files.' % (input_path,))
  return matched_files
def _read_files_then_shard(matched_files: List[str],
                           dataset_fn,
                           input_context: Optional[
                               tf.distribute.InputContext] = None,
                           sharding: bool = False,
                           repeat: bool = False) -> tf.data.Dataset:
  """Sends all data files to every worker and then shard by data."""
  dataset = dataset_fn(matched_files)

  # Every worker receives the full file list (possibly a single file), so
  # auto-sharding must be disabled — otherwise tf.data would attempt to split
  # the file list across input pipelines.
  opts = tf.data.Options()
  opts.experimental_distribute.auto_shard_policy = (
      tf.data.experimental.AutoShardPolicy.OFF)
  dataset = dataset.with_options(opts)

  # Shard by element instead. Callers disable `sharding` when tf.data service
  # is enabled, since the service performs its own sharding.
  if sharding and input_context and (input_context.num_input_pipelines > 1):
    dataset = dataset.shard(input_context.num_input_pipelines,
                            input_context.input_pipeline_id)

  if repeat:
    dataset = dataset.repeat()
  return dataset
def _shard_files_then_read(matched_files: List[str],
                           dataset_fn,
                           input_context: Optional[
                               tf.distribute.InputContext] = None,
                           seed: Optional[Union[int, tf.Tensor]] = None,
                           is_training: bool = False,
                           sharding: bool = False,
                           cache: bool = False,
                           cycle_length: Optional[int] = None,
                           block_length: Optional[int] = None,
                           deterministic: bool = False) -> tf.data.Dataset:
  """Shards the data files and then sends a split to every worker to read.

  Args:
    matched_files: List of input file paths.
    dataset_fn: Constructor mapping a file path to a `tf.data.Dataset`
      (used as the interleave map_func).
    input_context: Optional distribution input context used for sharding.
    seed: Optional file-shuffle seed; when sharding without a seed, a random
      one is generated so all workers shuffle files identically.
    is_training: Whether to shuffle (and, without cache, repeat) files.
    sharding: Whether to shard files across input pipelines.
    cache: If True, repeating is deferred (the caller repeats after cache()),
      and file shuffling is not reshuffled each iteration.
    cycle_length: Interleave cycle length.
    block_length: Interleave block length.
    deterministic: Whether the interleave output order is deterministic.

  Returns:
    A `tf.data.Dataset` produced by interleaving the (sharded) file list.
  """
  dataset = tf.data.Dataset.from_tensor_slices(matched_files)

  # Shuffle and repeat at file level.
  # If cache is enabled, `reshuffle_each_iteration` is set to False,
  # because we will read the same cached data in every iteration anyway.
  if is_training:
    # We need a seed to shuffle the files so that when each TPU workers gets
    # its own shard the files do not overlap.
    if sharding and seed is None:
      seed = _get_random_integer()
    dataset = dataset.shuffle(
        len(matched_files),
        seed=seed,
        # Idiom fix: was `True if not cache else False`.
        reshuffle_each_iteration=not cache)

  # Do not enable sharding if tf.data service is enabled, as sharding will be
  # handled inside tf.data service.
  if sharding and input_context and (input_context.num_input_pipelines > 1):
    dataset = dataset.shard(input_context.num_input_pipelines,
                            input_context.input_pipeline_id)

  # If cache is enabled, we will call `repeat()` later after `cache()`.
  if is_training and not cache:
    dataset = dataset.repeat()

  dataset = dataset.interleave(
      map_func=dataset_fn,
      cycle_length=cycle_length,
      block_length=block_length,
      num_parallel_calls=(cycle_length
                          if cycle_length else tf.data.experimental.AUTOTUNE),
      deterministic=deterministic)
  return dataset
def _read_tfds(tfds_name: Text,
               tfds_data_dir: Text,
               tfds_split: Text,
               tfds_skip_decoding_feature: Text,
               tfds_as_supervised: bool,
               input_context: Optional[tf.distribute.InputContext] = None,
               seed: Optional[Union[int, tf.Tensor]] = None,
               is_training: bool = False,
               cache: bool = False,
               cycle_length: Optional[int] = None,
               block_length: Optional[int] = None) -> tf.data.Dataset:
  """Reads a dataset from tfds.

  Args:
    tfds_name: TFDS dataset name. Names starting with 'mldataset.' take the
      direct `tfds.load` path below.
    tfds_data_dir: Data directory forwarded to the TFDS builder/loader.
    tfds_split: Split to read, e.g. 'train'.
    tfds_skip_decoding_feature: Comma-separated feature names whose decoding
      should be skipped.
    tfds_as_supervised: Whether to load (input, label) tuples.
    input_context: Optional distribution input context used for sharding.
    seed: Shuffle seed forwarded to `tfds.ReadConfig`.
    is_training: Whether files should be shuffled/repeated for training.
    cache: If True, filenames are not repeated here (caller caches and
      repeats later).
    cycle_length: Interleave cycle length for file reading.
    block_length: Interleave block length for file reading.

  Returns:
    A `tf.data.Dataset` produced by `tfds.load`.
  """
  # Repeat at filename level only when training without cache; with cache the
  # caller repeats after `cache()` instead.
  repeat_filenames = is_training and not cache
  read_config = tfds.ReadConfig(
      interleave_cycle_length=cycle_length,
      interleave_block_length=block_length,
      input_context=input_context,
      shuffle_seed=seed,
      repeat_filenames=repeat_filenames,
      # Only assert cardinality when we have a finite dataset.
      assert_cardinality=not repeat_filenames,
      skip_prefetch=True)

  decoders = {}
  if tfds_skip_decoding_feature:
    for skip_feature in tfds_skip_decoding_feature.split(','):
      decoders[skip_feature.strip()] = tfds.decode.SkipDecoding()

  if tfds_name.startswith('mldataset.'):
    dataset = tfds.load(name=tfds_name,
                        split=tfds_split,
                        as_supervised=tfds_as_supervised,
                        decoders=decoders if decoders else None,
                        read_config=read_config)
  else:
    builder = tfds.builder(tfds_name, data_dir=tfds_data_dir)
    if builder.info.splits:
      num_shards = len(builder.info.splits[tfds_split].file_instructions)
    else:
      # The tfds mock path often does not provide splits.
      num_shards = 1
    load_kwargs = dict(
        name=tfds_name, download=True, split=tfds_split,
        shuffle_files=is_training,
        as_supervised=tfds_as_supervised,
        decoders=decoders if decoders else None)
    if tfds_data_dir:
      load_kwargs.update({'data_dir': tfds_data_dir})

    if input_context and num_shards < input_context.num_input_pipelines:
      # The number of files in the dataset split is smaller than the number of
      # input pipelines. We read the entire dataset first and then shard in the
      # host memory.
      read_config = dataclasses.replace(read_config, input_context=None)
      load_kwargs.update({'read_config': read_config})
      dataset = tfds.load(**load_kwargs)
      dataset = dataset.shard(input_context.num_input_pipelines,
                              input_context.input_pipeline_id)
    else:
      load_kwargs.update({'read_config': read_config})
      dataset = tfds.load(**load_kwargs)
  return dataset
class
InputReader
:
"""Input reader that returns a tf.data.Dataset instance."""
# A static random number which is the same across different InputReader
# instances.
static_randnum
=
_get_random_integer
()
  def __init__(
      self,
      params: cfg.DataConfig,
      dataset_fn=tf.data.TFRecordDataset,
      decoder_fn: Optional[Callable[..., Any]] = None,
      combine_fn: Optional[Callable[..., Any]] = None,
      sample_fn: Optional[Callable[..., Any]] = None,
      parser_fn: Optional[Callable[..., Any]] = None,
      filter_fn: Optional[Callable[..., tf.Tensor]] = None,
      transform_and_batch_fn: Optional[
          Callable[
              [tf.data.Dataset, Optional[tf.distribute.InputContext]],
              tf.data.Dataset,
          ]
      ] = None,
      postprocess_fn: Optional[Callable[..., Any]] = None,
  ):
    """Initializes an InputReader instance.

    Args:
      params: A config_definitions.DataConfig object.
      dataset_fn: A `tf.data.Dataset` that consumes the input files. For
        example, it can be `tf.data.TFRecordDataset`.
      decoder_fn: An optional `callable` that takes the serialized data string
        and decodes them into the raw tensor dictionary.
      combine_fn: An optional `callable` that takes a dictionarty of
        `tf.data.Dataset` objects as input and outputs a combined dataset. It
        will be executed after the decoder_fn and before the sample_fn.
      sample_fn: An optional `callable` that takes a `tf.data.Dataset` object
        as input and outputs the transformed dataset. It performs sampling on
        the decoded raw tensors dict before the parser_fn.
      parser_fn: An optional `callable` that takes the decoded raw tensors
        dict and parse them into a dictionary of tensors that can be consumed
        by the model. It will be executed after decoder_fn.
      filter_fn: An optional `callable` mapping a dataset element to a
        boolean. It will be executed after parser_fn.
      transform_and_batch_fn: An optional `callable` that takes a
        `tf.data.Dataset` object and an optional `tf.distribute.InputContext`
        as input, and returns a `tf.data.Dataset` object. It will be executed
        after `parser_fn` to transform and batch the dataset; if None, after
        `parser_fn` is executed, the dataset will be batched into per-replica
        batch size.
      postprocess_fn: A optional `callable` that processes batched tensors.
        It will be executed after batching.

    Raises:
      ValueError: If both `input_path` and `tfds_name` are set, if a dict
        input is given without a `combine_fn`, or if `tfds_name` is used
        without `tfds_split`.
    """
    # `input_path` (files) and `tfds_name` (TFDS) are mutually exclusive
    # data sources.
    if params.input_path and params.tfds_name:
      raise ValueError('At most one of `input_path` and `tfds_name` can be '
                       'specified, but got %s and %s.' %
                       (params.input_path, params.tfds_name))

    # A Config-valued source means a dict of named datasets, which must be
    # merged by a combine_fn.
    if ((isinstance(params.input_path, cfg.base_config.Config) or
         isinstance(params.tfds_name, cfg.base_config.Config)) and
        combine_fn is None):
      raise ValueError(
          'A combine_fn is required if `input_path` or `tfds_name` is a dict.')

    self._tfds_name = params.tfds_name
    self._tfds_data_dir = params.tfds_data_dir
    self._matched_files = None
    if not params.input_path:
      # Read dataset from TFDS.
      if not params.tfds_split:
        raise ValueError(
            '`tfds_name` is %s, but `tfds_split` is not specified.' %
            params.tfds_name)
    else:
      self._matched_files = self.get_files(params.input_path)

    # Copy the dataset-pipeline knobs off the config.
    self._global_batch_size = params.global_batch_size
    self._is_training = params.is_training
    self._drop_remainder = params.drop_remainder
    self._shuffle_buffer_size = params.shuffle_buffer_size
    self._cache = params.cache
    self._cycle_length = params.cycle_length
    self._block_length = params.block_length
    self._deterministic = params.deterministic
    self._sharding = params.sharding
    self._tfds_split = params.tfds_split
    self._tfds_as_supervised = params.tfds_as_supervised
    self._tfds_skip_decoding_feature = params.tfds_skip_decoding_feature

    # User-provided pipeline stages (see docstring for execution order).
    self._dataset_fn = dataset_fn
    self._decoder_fn = decoder_fn
    self._combine_fn = combine_fn
    self._sample_fn = sample_fn
    self._parser_fn = parser_fn
    self._transform_and_batch_fn = transform_and_batch_fn
    self._postprocess_fn = postprocess_fn
    self._filter_fn = filter_fn
    self._seed = params.seed
    self._prefetch_buffer_size = (
        params.prefetch_buffer_size or tf.data.experimental.AUTOTUNE)
    self._autotune_algorithm = params.autotune_algorithm

    # When tf.data service is enabled, each data service worker should get
    # different random seeds. Thus, we set `seed` to None.
    # Sharding should also be disabled because tf data service handles how
    # each worker shard data with `processing_mode` in distribute method.
    if params.enable_tf_data_service:
      self._seed = None
      self._sharding = False

    self._enable_tf_data_service = (
        params.enable_tf_data_service and params.tf_data_service_address)
    self._tf_data_service_address = params.tf_data_service_address
    self._enable_shared_tf_data_service_between_parallel_trainers = (
        params.enable_shared_tf_data_service_between_parallel_trainers)
    self._apply_tf_data_service_before_batching = (
        params.apply_tf_data_service_before_batching)
    self._trainer_id = params.trainer_id
    if self._enable_tf_data_service:
      # Add a random seed as the tf.data service job name suffix, so tf.data
      # service doesn't reuse the previous state if TPU worker gets preempted.
      # It's necessary to add global batch size into the tf data service job
      # name because when tuning batch size with vizier and tf data service is
      # also enable, the tf data servce job name should be different for
      # different vizier trials since once batch size is changed, from the
      # tf.data perspective, the dataset is a different instance, and a
      # different job name should be used for tf data service. Otherwise, the
      # model would read tensors from the incorrect tf data service job, which
      # would causes dimension mismatch on the batch size dimension.
      self._tf_data_service_job_name = (
          f'{params.tf_data_service_job_name}_bs{params.global_batch_size}_'
          f'{self.static_randnum}')
      self._enable_round_robin_tf_data_service = params.get(
          'enable_round_robin_tf_data_service', False)
      if self._enable_shared_tf_data_service_between_parallel_trainers:
        # When shared tf.data service is enabled, only a single tf.data service
        # instance should be created and shared between parallel trainers. If
        # the global batch size is different across trainers,
        # params.apply_tf_data_service_before_batching should be set to true
        # because tf.data service with different batch sizes will be considered
        # separate tf.data service instances.
        self._tf_data_service_job_name = (
            f'{params.tf_data_service_job_name}_{self.static_randnum}')
def
get_files
(
self
,
input_path
):
"""Gets matched files. Can be overridden by subclasses."""
if
not
input_path
:
return
None
# we want to combine / mix datasets
if
isinstance
(
input_path
,
cfg
.
base_config
.
Config
):
matched_files
=
{}
for
k
,
v
in
input_path
.
as_dict
().
items
():
matched_files
[
k
]
=
match_files
(
v
)
# single dataset
else
:
matched_files
=
match_files
(
input_path
)
return
matched_files
  def _read_data_source(
      self,
      matched_files: Union[Dict[str, List[str]], List[str]],
      dataset_fn,
      input_context: Optional[tf.distribute.InputContext] = None,
  ):
    """Reads the data source (files/tfds) to a dataset.

    Args:
      matched_files: A list of file paths, or a dict mapping dataset name to
        a list of file paths (for combined/mixed datasets). Ignored when the
        reader was configured with a TFDS source.
      dataset_fn: Constructor mapping file paths to a `tf.data.Dataset`.
      input_context: Optional distribution input context used for sharding.

    Returns:
      A `tf.data.Dataset`, or a dict of name -> `tf.data.Dataset` when the
      source is a dict.

    Raises:
      ValueError: If `matched_files` is empty or of an unsupported type.
    """

    def _files_to_dataset(files: List[str]) -> tf.data.Dataset:
      # Chooses between the two file-reading strategies based on file count.
      if len(files) > 1:
        if input_context and (len(files) <
                              input_context.num_input_pipelines):
          # Fewer files than pipelines: fall back to sending every file to
          # every worker and sharding by data.
          logging.warn(
              ('The number of files %d is less than the number of input '
               'pipelines %d. We will send all input files to every worker. '
               'Please consider sharding your data into more files.'),
              len(files),
              input_context.num_input_pipelines,
          )
          return _read_files_then_shard(
              files,
              dataset_fn,
              input_context,
              sharding=self._sharding,
              repeat=self._is_training and not self._cache)
        else:
          return _shard_files_then_read(
              files,
              dataset_fn,
              input_context,
              seed=self._seed,
              is_training=self._is_training,
              sharding=self._sharding,
              cache=self._cache,
              cycle_length=self._cycle_length,
              block_length=self._block_length,
              deterministic=self._deterministic)
      elif len(files) == 1:
        # A single file cannot be sharded at the file level.
        return _read_files_then_shard(
            files,
            dataset_fn,
            input_context,
            sharding=self._sharding,
            repeat=self._is_training and not self._cache)
      else:
        raise ValueError('It is unexpected that `tfds_builder` is None and '
                         'there is also no `files`.')

    if self._tfds_name:
      # TFDS source; a Config value means a dict of named TFDS datasets.
      if isinstance(self._tfds_name, cfg.base_config.Config):
        dataset = {}
        for k, tfds_name in self._tfds_name.as_dict().items():
          dataset[k] = _read_tfds(
              tfds_name=tfds_name,
              tfds_data_dir=self._tfds_data_dir,
              tfds_split=self._tfds_split,
              tfds_skip_decoding_feature=self._tfds_skip_decoding_feature,
              tfds_as_supervised=self._tfds_as_supervised,
              input_context=input_context,
              seed=self._seed,
              is_training=self._is_training,
              cache=self._cache,
              cycle_length=self._cycle_length,
              block_length=self._block_length)
      else:
        dataset = _read_tfds(
            tfds_name=self._tfds_name,
            tfds_data_dir=self._tfds_data_dir,
            tfds_split=self._tfds_split,
            tfds_skip_decoding_feature=self._tfds_skip_decoding_feature,
            tfds_as_supervised=self._tfds_as_supervised,
            input_context=input_context,
            seed=self._seed,
            is_training=self._is_training,
            cache=self._cache,
            cycle_length=self._cycle_length,
            block_length=self._block_length)
    elif isinstance(matched_files, (list, tuple)):
      dataset = _files_to_dataset(matched_files)
    elif isinstance(matched_files, dict):
      dataset = {}
      for k, fs in matched_files.items():
        dataset[k] = _files_to_dataset(fs)
    else:
      raise ValueError('`matched_files` should be a list or dict.')

    return dataset
  def _decode_and_parse_dataset(
      self,
      dataset: Union[tf.data.Dataset, Dict[Text, tf.data.Dataset]],
      batch_size: int,
      input_context: Optional[tf.distribute.InputContext] = None
  ) -> tf.data.Dataset:
    """Returns a tf.data.Dataset object after shuffling, decoding, and parsing.

    Pipeline order: shuffle+decode (per source) -> combine -> sample ->
    parse -> filter -> cache(+repeat+shuffle) -> [tf.data service] ->
    transform_and_batch (or plain batch).

    Args:
      dataset: A `tf.data.Dataset`, or a dict of them (combined via
        `self._combine_fn`).
      batch_size: Global batch size; divided per replica unless
        `transform_and_batch_fn` takes over batching.
      input_context: Optional distribution input context.

    Returns:
      A batched `tf.data.Dataset`.
    """

    def _shuffle_and_decode(ds):
      # If cache is enabled, we will call `shuffle()` later after `cache()`.
      if self._is_training and not self._cache:
        ds = ds.shuffle(self._shuffle_buffer_size, seed=self._seed)
      # Decode
      ds = _maybe_map_fn(ds, self._decoder_fn)
      return ds

    # Applies shuffling/decoding to every dataset in a possibly-nested
    # structure (a dict of datasets for combined sources).
    dataset = tf.nest.map_structure(_shuffle_and_decode, dataset)
    if tf.nest.is_nested(dataset):
      dataset = self._combine_fn(dataset)

    if self._sample_fn is not None:
      dataset = dataset.apply(self._sample_fn)
    dataset = _maybe_map_fn(dataset, self._parser_fn)

    if self._filter_fn is not None:
      dataset = dataset.filter(self._filter_fn)

    if self._cache:
      # With cache enabled, repeat/shuffle deliberately happen AFTER cache()
      # so every epoch reuses the cached data.
      dataset = dataset.cache()
      if self._is_training:
        dataset = dataset.repeat()
        dataset = dataset.shuffle(self._shuffle_buffer_size, seed=self._seed)

    # Applies tf.data service before batching operations. This is useful when
    # tf.data service is shared between parallel trainers, and batch size is
    # changing between parallel trainers. Then batch size is changing, tf.data
    # services will be considered different instances if applied after batching
    # operations, which make it difficult to share between parallel trainers.
    # However, if there are additional expensive operations in
    # self._transform_and_batch_fn and self._postprocess_fn, the entire tf.data
    # pipeline could be slowed down. In this case, try to move these dataset
    # operations into early stages if possible.
    if (self._enable_shared_tf_data_service_between_parallel_trainers and
        self._apply_tf_data_service_before_batching):
      dataset = self._maybe_apply_data_service(dataset, input_context)

    if self._transform_and_batch_fn is not None:
      dataset = self._transform_and_batch_fn(dataset, input_context)
    else:
      per_replica_batch_size = input_context.get_per_replica_batch_size(
          batch_size) if input_context else batch_size
      dataset = dataset.batch(
          per_replica_batch_size, drop_remainder=self._drop_remainder)

    return dataset
  def _maybe_apply_data_service(
      self,
      dataset: tf.data.Dataset,
      input_context: Optional[tf.distribute.InputContext] = None
  ) -> tf.data.Dataset:
    """Potentially distributes a dataset.

    No-op unless tf.data service is enabled AND an input context is present.

    Args:
      dataset: The dataset to (maybe) route through tf.data service.
      input_context: Optional distribution input context.

    Returns:
      The (possibly distributed) dataset.

    Raises:
      ValueError: If round-robin mode is combined with shared tf.data service.
    """
    if self._enable_tf_data_service and input_context:
      if self._enable_round_robin_tf_data_service:
        # Round-robin: one consumer per replica, spread across input
        # pipelines.
        replicas_per_input_pipeline = input_context.num_replicas_in_sync // (
            input_context.num_input_pipelines)
        base_consumer_index = input_context.input_pipeline_id * (
            replicas_per_input_pipeline)
        num_consumers = input_context.num_input_pipelines * (
            replicas_per_input_pipeline)
        range_dataset = tf.data.Dataset.range(replicas_per_input_pipeline)
        tfds_kwargs = {
            'processing_mode': 'parallel_epochs',
            'service': self._tf_data_service_address,
            'job_name': self._tf_data_service_job_name,
            'num_consumers': num_consumers
        }
        if self._enable_shared_tf_data_service_between_parallel_trainers:
          raise ValueError('Shared tf.data service does not support round-robin'
                           ' tf.data service.')
        # One distribute() stream per local replica, each with its own
        # consumer index.
        dataset = range_dataset.map(lambda i: dataset.apply(  # pylint: disable=g-long-lambda
            tf.data.experimental.service.distribute(
                consumer_index=base_consumer_index + i, **tfds_kwargs)))
        # Use parallel interleave to read multiple batches from a tf.data
        # service worker in parallel.
        dataset = dataset.interleave(
            lambda x: x,
            cycle_length=replicas_per_input_pipeline,
            num_parallel_calls=replicas_per_input_pipeline,
            deterministic=True)
      else:
        tfds_kwargs = {
            'processing_mode': 'parallel_epochs',
            'service': self._tf_data_service_address,
            'job_name': self._tf_data_service_job_name,
        }
        if self._enable_shared_tf_data_service_between_parallel_trainers:
          # Shared mode: disable service-side sharding and use a cross-trainer
          # cache keyed by trainer id.
          tfds_kwargs.update({
              'processing_mode':
                  tf.data.experimental.service.ShardingPolicy.OFF,
              'cross_trainer_cache':
                  tf.data.experimental.service.CrossTrainerCache(
                      trainer_id=self._trainer_id)
          })
        dataset = dataset.apply(
            tf.data.experimental.service.distribute(**tfds_kwargs))
    return dataset
def
read
(
self
,
input_context
:
Optional
[
tf
.
distribute
.
InputContext
]
=
None
,
dataset
:
Optional
[
tf
.
data
.
Dataset
]
=
None
)
->
tf
.
data
.
Dataset
:
"""Generates a tf.data.Dataset object."""
if
dataset
is
None
:
dataset
=
self
.
_read_data_source
(
self
.
_matched_files
,
self
.
_dataset_fn
,
input_context
)
dataset
=
self
.
_decode_and_parse_dataset
(
dataset
,
self
.
_global_batch_size
,
input_context
)
dataset
=
_maybe_map_fn
(
dataset
,
self
.
_postprocess_fn
)
if
not
(
self
.
_enable_shared_tf_data_service_between_parallel_trainers
and
self
.
_apply_tf_data_service_before_batching
):
dataset
=
self
.
_maybe_apply_data_service
(
dataset
,
input_context
)
if
self
.
_deterministic
is
not
None
:
options
=
tf
.
data
.
Options
()
options
.
deterministic
=
self
.
_deterministic
dataset
=
dataset
.
with_options
(
options
)
if
self
.
_autotune_algorithm
:
options
=
tf
.
data
.
Options
()
options
.
autotune
.
autotune_algorithm
=
(
tf
.
data
.
experimental
.
AutotuneAlgorithm
[
self
.
_autotune_algorithm
])
dataset
=
dataset
.
with_options
(
options
)
return
dataset
.
prefetch
(
self
.
_prefetch_buffer_size
)
models-2.13.1/official/core/registry.py
0 → 100644
View file @
472e2f80
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Registry utility."""
def register(registered_collection, reg_key):
  """Register decorated function or class to collection.

  Register decorated function or class into registered_collection, in a
  hierarchical order. For example, when reg_key="my_model/my_exp/my_config_0"
  the decorated function or class is stored under
  registered_collection["my_model"]["my_exp"]["my_config_0"].
  This decorator is supposed to be used together with the lookup() function in
  this file.

  Args:
    registered_collection: a dictionary. The decorated function or class will be
      put into this collection.
    reg_key: The key for retrieving the registered function or class. If reg_key
      is a string, it can be hierarchical like my_model/my_exp/my_config_0

  Returns:
    A decorator function

  Raises:
    KeyError: when function or class to register already exists.
  """

  def decorator(fn_or_cls):
    """Stores fn_or_cls in the collection under reg_key."""
    if isinstance(reg_key, str):
      # Walk (lazily creating) the nested dicts for every path segment except
      # the last, which names the leaf entry itself.
      path = reg_key.split("/")
      node = registered_collection
      for depth, segment in enumerate(path[:-1]):
        node = node.setdefault(segment, {})
        if not isinstance(node, dict):
          # An intermediate segment was previously registered as a leaf.
          raise KeyError(
              "Collection path {} at position {} already registered as "
              "a function or class.".format(segment, depth))
      leaf_key = path[-1]
    else:
      # Non-string keys are stored flat at the top level of the collection.
      node = registered_collection
      leaf_key = reg_key
    if leaf_key in node:
      raise KeyError("Function or class {} registered multiple times.".format(
          leaf_key))
    node[leaf_key] = fn_or_cls
    return fn_or_cls

  return decorator
def lookup(registered_collection, reg_key):
  """Lookup and return decorated function or class in the collection.

  Lookup decorated function or class in registered_collection, in a
  hierarchical order. For example, when
  reg_key="my_model/my_exp/my_config_0",
  this function will return
  registered_collection["my_model"]["my_exp"]["my_config_0"].

  Args:
    registered_collection: a dictionary. The decorated function or class will be
      retrieved from this collection.
    reg_key: The key for retrieving the registered function or class. If reg_key
      is a string, it can be hierarchical like my_model/my_exp/my_config_0

  Returns:
    The registered function or class.

  Raises:
    LookupError: when reg_key cannot be found.
  """
  if isinstance(reg_key, str):
    hierarchy = reg_key.split("/")
    collection = registered_collection
    for h_idx, entry_name in enumerate(hierarchy):
      # A non-dict node means an ancestor segment was registered as a leaf,
      # so the requested path cannot resolve. Report this as the documented
      # LookupError rather than letting the `in` test raise a TypeError.
      if not isinstance(collection, dict) or entry_name not in collection:
        raise LookupError(
            f"collection path {entry_name} at position {h_idx} is never "
            f"registered. Please make sure the {entry_name} and its library is "
            "imported and linked to the trainer binary.")
      collection = collection[entry_name]
    return collection
  else:
    # Non-string keys are stored flat at the top level of the collection.
    if reg_key not in registered_collection:
      raise LookupError(
          f"registration key {reg_key} is never "
          f"registered. Please make sure the {reg_key} and its library is "
          "imported and linked to the trainer binary.")
    return registered_collection[reg_key]
models-2.13.1/official/core/registry_test.py
0 → 100644
View file @
472e2f80
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for registry."""
import
tensorflow
as
tf
from
official.core
import
registry
class RegistryTest(tf.test.TestCase):
  """Tests for the register()/lookup() helpers in official.core.registry."""

  def test_register(self):
    reg = {}

    @registry.register(reg, 'functions/func_0')
    def registered_fn():
      pass

    self.assertEqual(registry.lookup(reg, 'functions/func_0'), registered_fn)

    @registry.register(reg, 'classes/cls_0')
    class KeyClass:
      pass

    self.assertEqual(registry.lookup(reg, 'classes/cls_0'), KeyClass)

    # A class object itself can serve as a (non-string) registration key.
    @registry.register(reg, KeyClass)
    class ValueClass:
      pass

    self.assertEqual(registry.lookup(reg, KeyClass), ValueClass)

  def test_register_hierarchy(self):
    reg = {}

    @registry.register(reg, 'functions/func_0')
    def nested_fn():
      pass

    @registry.register(reg, 'func_1')
    def flat_fn():
      pass

    @registry.register(reg, flat_fn)
    def keyed_by_fn():
      pass

    # String keys containing '/' nest; other keys (including callables)
    # are stored flat at the top level.
    self.assertEqual(
        reg,
        {
            'functions': {'func_0': nested_fn},
            'func_1': flat_fn,
            flat_fn: keyed_by_fn,
        })

  def test_register_error(self):
    reg = {}

    @registry.register(reg, 'functions/func_0')
    def occupied_fn():  # pylint: disable=unused-variable
      pass

    # Registering beneath an existing leaf must fail.
    with self.assertRaises(KeyError):

      @registry.register(reg, 'functions/func_0/sub_func')
      def conflicting_fn():  # pylint: disable=unused-variable
        pass

    # Looking up an unknown key must fail.
    with self.assertRaises(LookupError):
      registry.lookup(reg, 'non-exist')
# Run the test suite through the TensorFlow test runner when executed directly.
if __name__ == '__main__':
  tf.test.main()
Prev
1
2
3
4
5
6
7
8
9
…
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment