ModelZoo / ResNet50_tensorflow · Commits

Commit 78c43ef1, authored Jul 26, 2021 by Gunho Park
Merge branch 'master' of https://github.com/tensorflow/models
Parents: 67cfc95b, e3c7e300

Changes: 227 files in total. Showing 20 changed files with 534 additions and 160 deletions (page 1 of 12).
| File | Additions | Deletions |
|------|-----------|-----------|
| README.md | +15 | -0 |
| official/README.md | +4 | -18 |
| official/common/distribute_utils.py | +6 | -3 |
| official/common/distribute_utils_test.py | +1 | -1 |
| official/common/flags.py | +20 | -2 |
| official/core/actions.py | +156 | -0 |
| official/core/actions_test.py | +81 | -0 |
| official/core/base_task.py | +1 | -1 |
| official/core/base_trainer_test.py | +8 | -5 |
| official/core/config_definitions.py | +16 | -15 |
| official/core/export_base.py | +1 | -1 |
| official/core/input_reader.py | +78 | -45 |
| official/core/train_lib.py | +18 | -4 |
| official/core/train_utils.py | +75 | -9 |
| official/modeling/multitask/configs.py | +2 | -1 |
| official/modeling/multitask/evaluator.py | +22 | -21 |
| official/modeling/multitask/evaluator_test.py | +3 | -8 |
| official/modeling/multitask/multitask.py | +3 | -9 |
| official/modeling/multitask/train_lib.py | +13 | -7 |
| official/modeling/multitask/train_lib_test.py | +11 | -10 |
README.md  (+15, -0)

````diff
@@ -30,3 +30,18 @@ If you want to contribute, please review the [contribution guidelines](https://g
 ## License
 
 [Apache License 2.0](LICENSE)
+
+## Citing TensorFlow Model Garden
+
+If you use TensorFlow Model Garden in your research, please cite this repository.
+
+```
+@misc{tensorflowmodelgarden2020,
+  author = {Hongkun Yu and Chen Chen and Xianzhi Du and Yeqing Li and
+            Abdullah Rashwan and Le Hou and Pengchong Jin and Fan Yang and
+            Frederick Liu and Jaeyoun Kim and Jing Li},
+  title = {{TensorFlow Model Garden}},
+  howpublished = {\url{https://github.com/tensorflow/models}},
+  year = {2020}
+}
+```
````
official/README.md  (+4, -18)

The image-classification and detection model tables now point to the beta Model Garden pages, and the per-directory citation section is removed in favor of the repository-level one added to the top-level README above.

```diff
@@ -40,7 +40,7 @@ In the near future, we will add:
 | Model | Reference (Paper) |
 |-------|-------------------|
 | [MNIST](vision/image_classification) | A basic model to classify digits from the [MNIST dataset](http://yann.lecun.com/exdb/mnist/) |
-| [ResNet](vision/image_classification) | [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) |
+| [ResNet](vision/beta/MODEL_GARDEN.md) | [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) |
 | [ResNet-RS](vision/beta/MODEL_GARDEN.md) | [Revisiting ResNets: Improved Training and Scaling Strategies](https://arxiv.org/abs/2103.07579) |
 | [EfficientNet](vision/image_classification) | [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks](https://arxiv.org/abs/1905.11946) |
@@ -48,10 +48,10 @@ In the near future, we will add:
 | Model | Reference (Paper) |
 |-------|-------------------|
-| [RetinaNet](vision/detection) | [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002) |
-| [Mask R-CNN](vision/detection) | [Mask R-CNN](https://arxiv.org/abs/1703.06870) |
+| [RetinaNet](vision/beta/MODEL_GARDEN.md) | [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002) |
+| [Mask R-CNN](vision/beta/MODEL_GARDEN.md) | [Mask R-CNN](https://arxiv.org/abs/1703.06870) |
 | [ShapeMask](vision/detection) | [ShapeMask: Learning to Segment Novel Objects by Refining Shape Priors](https://arxiv.org/abs/1904.03239) |
-| [SpineNet](vision/detection) | [SpineNet: Learning Scale-Permuted Backbone for Recognition and Localization](https://arxiv.org/abs/1912.05027) |
+| [SpineNet](vision/beta/MODEL_GARDEN.md) | [SpineNet: Learning Scale-Permuted Backbone for Recognition and Localization](https://arxiv.org/abs/1912.05027) |
 
 ### Natural Language Processing
```

````diff
@@ -163,17 +163,3 @@ pip3 install tensorflow-text-nightly
 ## Contributions
 
 If you want to contribute, please review the [contribution guidelines](https://github.com/tensorflow/models/wiki/How-to-contribute).
-
-## Citing TF Official Model Garden
-
-To cite this repository:
-
-```
-@software{tfmodels2020github,
-  author = {Chen Chen and Xianzhi Du and Le Hou and Jaeyoun Kim and Jing Li and
-            Yeqing Li and Abdullah Rashwan and Fan Yang and Hongkun Yu},
-  title = {TensorFlow Official Model Garden},
-  url = {https://github.com/tensorflow/models/tree/master/official},
-  year = {2020},
-}
-```
````
official/common/distribute_utils.py  (+6, -3)

`get_distribution_strategy` with `distribution_strategy="off"` no longer returns `None`; it now returns the default strategy from `tf.distribute.get_strategy()`.

```diff
@@ -102,8 +102,10 @@ def get_distribution_strategy(distribution_strategy="mirrored",
     distribution_strategy: a string specifying which distribution strategy to
       use. Accepted values are "off", "one_device", "mirrored",
       "parameter_server", "multi_worker_mirrored", and "tpu" -- case
-      insensitive. "off" means not to use Distribution Strategy; "tpu" means to
-      use TPUStrategy using `tpu_address`.
+      insensitive. "tpu" means to use TPUStrategy using `tpu_address`.
+      "off" means to use the default strategy which is obtained from
+      tf.distribute.get_strategy (for details on the default strategy, see
+      https://www.tensorflow.org/guide/distributed_training#default_strategy).
     num_gpus: Number of GPUs to run this model.
     all_reduce_alg: Optional. Specifies which algorithm to use when performing
       all-reduce. For `MirroredStrategy`, valid values are "nccl" and
@@ -141,7 +143,8 @@ def get_distribution_strategy(distribution_strategy="mirrored",
     if num_gpus > 1:
       raise ValueError("When {} GPUs are specified, distribution_strategy "
                        "flag cannot be set to `off`.".format(num_gpus))
-    return None
+    # Return the default distribution strategy.
+    return tf.distribute.get_strategy()
 
   if distribution_strategy == "tpu":
     # When tpu_address is an empty string, we communicate with local TPUs.
```
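A minimal sketch of the caller-side effect of this change, assuming the TF Model Garden `official` package is importable: callers no longer need to special-case a `None` return for `"off"`, because the default strategy's `scope()` is a valid single-replica context manager.

```python
import tensorflow as tf
from official.common import distribute_utils

# Before this commit, 'off' returned None and callers had to branch on it;
# now it returns tf.distribute.get_strategy(), so scope() always works.
strategy = distribute_utils.get_distribution_strategy("off")
with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
```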
official/common/distribute_utils_test.py  (+1, -1)

```diff
@@ -43,7 +43,7 @@ class GetDistributionStrategyTest(tf.test.TestCase):
   def test_no_strategy(self):
     ds = distribute_utils.get_distribution_strategy('off')
-    self.assertIsNone(ds)
+    self.assertIs(ds, tf.distribute.get_strategy())
 
   def test_invalid_strategy(self):
     with self.assertRaisesRegexp(
```
official/common/flags.py  (+20, -2)

````diff
@@ -18,9 +18,27 @@ from absl import flags
 def define_flags():
-  """Defines flags."""
+  """Defines flags.
+
+  All flags are defined as optional, but in practice most models use some of
+  these flags and so mark_flags_as_required() should be called after calling
+  this function. Typically, 'experiment', 'mode', and 'model_dir' are required.
+  For example:
+
+  ```
+  from absl import flags
+  from official.common import flags as tfm_flags  # pylint: disable=line-too-long
+  ...
+  tfm_flags.define_flags()
+  flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
+  ```
+
+  The reason all flags are optional is because unit tests often do not set or
+  use any of the flags.
+  """
   flags.DEFINE_string(
-      'experiment', default=None, help='The experiment type registered.')
+      'experiment', default=None,
+      help='The experiment type registered, specifying an ExperimentConfig.')
 
   flags.DEFINE_enum(
       'mode',
````
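A runnable sketch of the pattern the new docstring prescribes; the `main` function here is a hypothetical entry point, not part of this commit.

```python
from absl import app
from absl import flags
from official.common import flags as tfm_flags

def main(_):
  # A hypothetical driver that relies on the three required flags.
  print(flags.FLAGS.experiment, flags.FLAGS.mode, flags.FLAGS.model_dir)

if __name__ == '__main__':
  tfm_flags.define_flags()
  flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
  app.run(main)
```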
official/core/actions.py  (new file, 0 → 100644, +156)

```python
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Provides TFM orbit actions and associated helper functions/classes."""
import os
from typing import List

import gin
import orbit
import tensorflow as tf
import tensorflow_model_optimization as tfmot

from official.core import base_trainer
from official.core import config_definitions
from official.modeling import optimization


class PruningActions:
  """Train action to update pruning-related information.

  This action updates the pruning step at the end of the training loop, and
  logs pruning metrics to TensorBoard.

  This action must be used when training a pruned model to avoid pruning
  errors.
  """

  def __init__(
      self,
      export_dir: str,
      model: tf.keras.Model,
      optimizer: tf.keras.optimizers.Optimizer,
  ):
    """Initializes the instance.

    Args:
      export_dir: `str` for the export directory of the pruning summaries.
      model: `tf.keras.Model` model instance used for training. This will be
        used to assign a pruning step to each prunable weight.
      optimizer: `tf.keras.optimizers.Optimizer` optimizer instance used for
        training. This will be used to find the current training steps.
    """
    self._optimizer = optimizer
    self.update_pruning_step = tfmot.sparsity.keras.UpdatePruningStep()
    self.update_pruning_step.set_model(model)
    self.update_pruning_step.on_train_begin()

    self.pruning_summaries = tfmot.sparsity.keras.PruningSummaries(
        log_dir=export_dir)
    model.optimizer = optimizer
    self.pruning_summaries.set_model(model)

  def __call__(self, output: orbit.runner.Output):
    """Updates the pruning step and logs pruning summaries.

    Args:
      output: The train output.
    """
    self.update_pruning_step.on_epoch_end(batch=None)
    self.pruning_summaries.on_epoch_begin(epoch=None)


class EMACheckpointing:
  """Eval action to save checkpoint with average weights when EMA is used.

  This action swaps the weights of the model with the average weights, then it
  saves the checkpoint under export_dir/ema_checkpoints. Checkpointing is
  expensive for large models, so doing this action in eval is more efficient
  than in training.
  """

  def __init__(self,
               export_dir: str,
               optimizer: tf.keras.optimizers.Optimizer,
               checkpoint: tf.train.Checkpoint,
               max_to_keep: int = 1):
    """Initializes the instance.

    Args:
      export_dir: `str` for the export directory of the EMA average weights.
      optimizer: `tf.keras.optimizers.Optimizer` optimizer instance used for
        training. This will be used to swap the model weights with the average
        weights.
      checkpoint: `tf.train.Checkpoint` instance.
      max_to_keep: `int` for max checkpoints to keep in the ema_checkpoints
        subdirectory.
    """
    if not isinstance(optimizer, optimization.ExponentialMovingAverage):
      raise ValueError('Optimizer has to be instance of '
                       'optimization.ExponentialMovingAverage for '
                       'EMACheckpointing action')

    export_dir = os.path.join(export_dir, 'ema_checkpoints')
    tf.io.gfile.makedirs(os.path.dirname(export_dir))
    self._optimizer = optimizer
    self._checkpoint = checkpoint
    self._checkpoint_manager = tf.train.CheckpointManager(
        checkpoint,
        directory=export_dir,
        max_to_keep=max_to_keep,
        checkpoint_name='average_weights')

  def __call__(self, output: orbit.runner.Output):
    """Swaps model weights, and saves the checkpoint.

    Args:
      output: The train or eval output.
    """
    # Swap in the EMA weights, save them, then swap the live weights back.
    self._optimizer.swap_weights()
    self._checkpoint_manager.save(checkpoint_number=self._optimizer.iterations)
    self._optimizer.swap_weights()


@gin.configurable
def get_eval_actions(params: config_definitions.ExperimentConfig,
                     trainer: base_trainer.Trainer,
                     model_dir: str) -> List[orbit.Action]:
  """Gets eval actions for the TFM trainer."""
  eval_actions = []
  # Adds EMA checkpointing action to save the average weights under the
  # ema_checkpoints subdir.
  if isinstance(trainer.optimizer, optimization.ExponentialMovingAverage):
    eval_actions.append(
        EMACheckpointing(
            export_dir=model_dir,
            optimizer=trainer.optimizer,
            checkpoint=trainer.checkpoint,
            max_to_keep=params.trainer.max_to_keep))

  return eval_actions


@gin.configurable
def get_train_actions(params: config_definitions.ExperimentConfig,
                      trainer: base_trainer.Trainer,
                      model_dir: str) -> List[orbit.Action]:
  """Gets train actions for the TFM trainer."""
  train_actions = []
  # Adds pruning callback actions.
  if hasattr(params.task, 'pruning'):
    train_actions.append(
        PruningActions(
            export_dir=model_dir,
            model=trainer.model,
            optimizer=trainer.optimizer))

  return train_actions
```
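A sketch of how these actions get wired up, following the `train_lib.py` change later in this commit; `trainer`, `params`, and `model_dir` are assumed to already exist as in `official.core.train_lib.run_experiment`, and other `orbit.Controller` arguments are omitted.

```python
import orbit
from official.core import actions

controller = orbit.Controller(
    trainer=trainer,
    evaluator=trainer,
    global_step=trainer.global_step,
    steps_per_loop=100,
    # Orbit invokes each action with the train/eval output after each train
    # loop or evaluation, which is how PruningActions and EMACheckpointing
    # above get called.
    train_actions=actions.get_train_actions(params, trainer, model_dir),
    eval_actions=actions.get_eval_actions(params, trainer, model_dir))
```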
official/core/actions_test.py  (new file, 0 → 100644, +81)

```python
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for TFM actions."""
import os

from absl.testing import parameterized
import tensorflow as tf

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.core import actions
from official.modeling import optimization


class TestModel(tf.Module):

  def __init__(self):
    self.value = tf.Variable(0)

  @tf.function(input_signature=[])
  def __call__(self):
    return self.value


def all_strategy_combinations():
  return combinations.combine(
      distribution=[
          strategy_combinations.cloud_tpu_strategy,
          strategy_combinations.one_device_strategy_gpu,
      ],)


class ActionsTest(tf.test.TestCase, parameterized.TestCase):

  @combinations.generate(all_strategy_combinations())
  def test_ema_checkpointing(self, distribution):
    with distribution.scope():
      directory = self.create_tempdir()
      model = TestModel()
      optimizer = tf.keras.optimizers.SGD()
      optimizer = optimization.ExponentialMovingAverage(
          optimizer, trainable_weights_only=False)

      # Creates average weights for the model variables. Average weights are
      # initialized to zero.
      optimizer.shadow_copy(model)
      checkpoint = tf.train.Checkpoint(model=model)

      # Changes model.value to 3; the average value is still 0.
      model.value.assign(3)
      # Checks model.value is 3.
      self.assertEqual(model(), 3)

      ema_action = actions.EMACheckpointing(directory, optimizer, checkpoint)

      ema_action({})
      self.assertNotEmpty(
          tf.io.gfile.glob(os.path.join(directory, 'ema_checkpoints')))

      checkpoint.read(
          tf.train.latest_checkpoint(
              os.path.join(directory, 'ema_checkpoints')))

      # Checks model.value is 0 after swapping.
      self.assertEqual(model(), 0)


if __name__ == '__main__':
  tf.test.main()
```
official/core/base_task.py  (+1, -1)

```diff
@@ -79,7 +79,7 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
     optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
 
     # Configuring optimizer when loss_scale is set in runtime config. This helps
     # avoiding overflow/underflow for float16 computations.
-    if runtime_config and runtime_config.loss_scale:
+    if runtime_config:
       optimizer = performance.configure_optimizer(
           optimizer,
           use_float16=runtime_config.mixed_precision_dtype == "float16",
```
official/core/base_trainer_test.py  (+8, -5)

```diff
@@ -303,13 +303,16 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
         },
     })))
     trainer = self.create_test_trainer(config)
-    if mixed_precision_dtype != 'float16':
-      self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)
-    elif mixed_precision_dtype == 'float16' and loss_scale is None:
-      self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)
+    if mixed_precision_dtype == 'float16':
+      self.assertIsInstance(trainer.optimizer,
+                            tf.keras.mixed_precision.LossScaleOptimizer)
+      if loss_scale in (None, 'dynamic'):
+        self.assertTrue(trainer.optimizer.dynamic)
+      else:
+        self.assertFalse(trainer.optimizer.dynamic)
+        self.assertEqual(trainer.optimizer.initial_scale, loss_scale)
     else:
-      self.assertIsInstance(trainer.optimizer,
-                            tf.keras.mixed_precision.LossScaleOptimizer)
+      self.assertIsInstance(trainer.optimizer, tf.keras.optimizers.SGD)
 
     metrics = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
     self.assertIn('training_loss', metrics)
```
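A standalone illustration of the `LossScaleOptimizer` semantics the new assertions encode, using only public `tf.keras` APIs: the wrapper defaults to dynamic loss scaling, and a fixed scale implies `dynamic=False` with an explicit `initial_scale`.

```python
import tensorflow as tf

# Default construction uses dynamic loss scaling.
dynamic_opt = tf.keras.mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.SGD())
assert dynamic_opt.dynamic

# A fixed loss scale is expressed as dynamic=False plus an initial_scale.
fixed_opt = tf.keras.mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.SGD(), dynamic=False, initial_scale=128)
assert not fixed_opt.dynamic
assert fixed_opt.initial_scale == 128
```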
official/core/config_definitions.py  (+16, -15)

```diff
@@ -29,12 +29,13 @@ class DataConfig(base_config.Config):
   """The base configuration for building datasets.
 
   Attributes:
-    input_path: The path to the input. It can be either (1) a str indicating
-      a file path/pattern, or (2) a str indicating multiple file paths/patterns
-      separated by comma (e.g "a, b, c" or no spaces "a,b,c"), or
-      (3) a list of str, each of which is a file path/pattern or multiple file
-      paths/patterns separated by comma.
-      It should not be specified when the following `tfds_name` is specified.
+    input_path: The path to the input. It can be either (1) a str indicating a
+      file path/pattern, or (2) a str indicating multiple file paths/patterns
+      separated by comma (e.g "a, b, c" or no spaces "a,b,c"), or (3) a list of
+      str, each of which is a file path/pattern or multiple file paths/patterns
+      separated by comma, or (4) a dictionary of the previous three approaches
+      for more advanced data mixing using named access. It should not be
+      specified when the following `tfds_name` is specified.
     tfds_name: The name of the tensorflow dataset (TFDS). It should not be
       specified when the above `input_path` is specified.
     tfds_split: A str indicating which split of the data to load from TFDS. It
@@ -46,8 +47,8 @@ class DataConfig(base_config.Config):
     shuffle_buffer_size: The buffer size used for shuffling training data.
     cache: Whether to cache dataset examples. If `True`, we will cache the
       dataset after applying the decode_fn and parse_fn. It can be used to avoid
-      re-reading from disk, re-decoding and re-parsing the example on the
-      second epoch, but it requires significant memory overhead.
+      re-reading from disk, re-decoding and re-parsing the example on the second
+      epoch, but it requires significant memory overhead.
     cycle_length: The number of files that will be processed concurrently when
       interleaving files.
     block_length: The number of consecutive elements to produce from each input
@@ -59,11 +60,10 @@ class DataConfig(base_config.Config):
     tf_data_service_address: The URI of a tf.data service to offload
       preprocessing onto during training. The URI should be in the format
       "protocol://address", e.g. "grpc://tf-data-service:5050". It can be
-      overridden by `FLAGS.tf_data_service` flag in the binary.
-    tf_data_service_job_name: The name of the tf.data service job. This
-      argument makes it possible for multiple datasets to share the same job.
-      The default behavior is that the dataset creates anonymous, exclusively
-      owned jobs.
+      overridden by `FLAGS.tf_data_service` flag in the binary.
+    tf_data_service_job_name: The name of the tf.data service job. This argument
+      makes it possible for multiple datasets to share the same job. The default
+      behavior is that the dataset creates anonymous, exclusively owned jobs.
     tfds_data_dir: A str specifying the directory to read/write TFDS data.
     tfds_as_supervised: A bool. When loading dataset from TFDS, if True, the
       returned tf.data.Dataset will have a 2-tuple structure (input, label)
@@ -75,7 +75,7 @@ class DataConfig(base_config.Config):
       performance.
     seed: An optional seed to use for deterministic shuffling/preprocessing.
   """
-  input_path: Union[Sequence[str], str] = ""
+  input_path: Union[Sequence[str], str, base_config.Config] = ""
   tfds_name: str = ""
   tfds_split: str = ""
   global_batch_size: int = 0
@@ -239,9 +239,10 @@ class TrainerConfig(base_config.Config):
 @dataclasses.dataclass
 class TaskConfig(base_config.Config):
   init_checkpoint: str = ""
-  model: base_config.Config = None
+  model: Optional[base_config.Config] = None
   train_data: DataConfig = DataConfig()
   validation_data: DataConfig = DataConfig()
+  name: Optional[str] = None
 
 @dataclasses.dataclass
```
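A hypothetical sketch of option (4) from the new docstring, a dict-valued `input_path` for named data mixing; the dataset keys and file patterns here are made up, and the dict is consumed by `InputReader` together with a `combine_fn` (see the `input_reader.py` section below).

```python
from official.core import config_definitions as cfg

# Two named input sources that a combine_fn can later mix by key.
data = cfg.DataConfig(
    input_path={
        'coco': '/data/coco/train*',   # assumed paths, illustration only
        'voc': '/data/voc/train*',
    },
    global_batch_size=64,
    is_training=True)
```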
official/core/export_base.py  (+1, -1)

```diff
@@ -82,7 +82,7 @@ def export(export_module: ExportModule,
     The savedmodel directory path.
   """
   ckpt_dir_or_file = checkpoint_path
-  if tf.io.gfile.isdir(ckpt_dir_or_file):
+  if ckpt_dir_or_file is not None and tf.io.gfile.isdir(ckpt_dir_or_file):
     ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
   if ckpt_dir_or_file:
     checkpoint = tf.train.Checkpoint(model=export_module.model)
```
official/core/input_reader.py  (+78, -45)

`InputReader` gains a `combine_fn` hook and accepts a dict-valued `input_path` for mixing multiple named datasets; the tf.data-service seed/sharding handling is also reworked.

```diff
@@ -14,7 +14,7 @@
 """A common dataset reader."""
 import random
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, List, Optional, Union, Dict, Sequence
 
 from absl import logging
 import tensorflow as tf
@@ -45,6 +45,7 @@ class InputReader:
                params: cfg.DataConfig,
                dataset_fn=tf.data.TFRecordDataset,
                decoder_fn: Optional[Callable[..., Any]] = None,
+               combine_fn: Optional[Callable[..., Any]] = None,
                sample_fn: Optional[Callable[..., Any]] = None,
                parser_fn: Optional[Callable[..., Any]] = None,
                transform_and_batch_fn: Optional[Callable[
@@ -59,6 +60,9 @@ class InputReader:
         example, it can be `tf.data.TFRecordDataset`.
       decoder_fn: An optional `callable` that takes the serialized data string
         and decodes them into the raw tensor dictionary.
+      combine_fn: An optional `callable` that takes a dictionary of
+        `tf.data.Dataset` objects as input and outputs a combined dataset. It
+        will be executed after the decoder_fn and before the sample_fn.
       sample_fn: An optional `callable` that takes a `tf.data.Dataset` object as
         input and outputs the transformed dataset. It performs sampling on the
         decoded raw tensors dict before the parser_fn.
@@ -78,10 +82,23 @@ class InputReader:
       raise ValueError(
           'At most one of `input_path` and `tfds_name` can be '
           'specified, but got %s and %s.' %
          (params.input_path, params.tfds_name))
+    if (isinstance(params.input_path, cfg.base_config.Config) and
+        combine_fn is None):
+      raise ValueError(
+          'A `combine_fn` is required if the `input_path` is a dictionary.')
 
     self._tfds_builder = None
-    self._matched_files = []
+    self._matched_files = None
     if params.input_path:
-      self._matched_files = self._match_files(params.input_path)
+      # we want to combine / mix datasets
+      if isinstance(params.input_path, cfg.base_config.Config):
+        self._matched_files = {}
+        for k, v in params.input_path.as_dict().items():
+          self._matched_files[k] = self._match_files(v)
+      # single dataset
+      else:
+        self._matched_files = self._match_files(params.input_path)
     else:
       # Read dataset from TFDS.
       if not params.tfds_split:
@@ -106,18 +123,20 @@ class InputReader:
     self._dataset_fn = dataset_fn
     self._decoder_fn = decoder_fn
+    self._combine_fn = combine_fn
     self._sample_fn = sample_fn
     self._parser_fn = parser_fn
     self._transform_and_batch_fn = transform_and_batch_fn
     self._postprocess_fn = postprocess_fn
-    self._seed = params.seed
 
     # When tf.data service is enabled, each data service worker should get
     # different random seeds. Thus, we set `seed` to None.
+    if params.seed is not None:
+      self._seed = params.seed
+    elif params.enable_tf_data_service:
+      self._seed = _get_random_integer()
+    else:
+      self._seed = None
     # Sharding should also be disabled because tf data service handles how
     # each worker shard data with `processing_mode` in distribute method.
     if params.enable_tf_data_service:
-      self._seed = None
       self._sharding = False
 
     self._enable_tf_data_service = (
         params.enable_tf_data_service and params.tf_data_service_address)
@@ -130,7 +149,7 @@ class InputReader:
     self._enable_round_robin_tf_data_service = params.get(
         'enable_round_robin_tf_data_service', False)
 
-  def _match_files(self, input_path: str) -> List[str]:
+  def _match_files(self, input_path: Union[Sequence[str], str]) -> List[str]:
     """Matches files from an input_path."""
     matched_files = []
     # Read dataset from files.
@@ -181,16 +200,21 @@ class InputReader:
     # If cache is enabled, `reshuffle_each_iteration` is set to False,
     # because we will read the same cached data in every iteration anyway.
     if self._is_training:
+      # We need a seed to shuffle the files so that when each TPU workers gets
+      # its own shard the files do not overlap.
+      if self._sharding and self._seed is None:
+        seed = _get_random_integer()
+      else:
+        seed = self._seed
       dataset = dataset.shuffle(
           len(matched_files),
-          seed=self._seed,
+          seed=seed,
           reshuffle_each_iteration=True if not self._cache else False)
-    # Do not enable sharding if tf.data service is enabled, as sharding will be
-    # handled inside tf.data service.
-    if self._sharding and input_context and (
-        input_context.num_input_pipelines > 1 and
-        not self._enable_tf_data_service):
+    if self._sharding and input_context and (
+        input_context.num_input_pipelines > 1):
       dataset = dataset.shard(input_context.num_input_pipelines,
                               input_context.input_pipeline_id)
@@ -225,9 +249,8 @@ class InputReader:
     dataset = dataset.with_options(options)
-    # Do not enable sharding if tf.data service is enabled, as sharding will be
-    # handled inside tf.data service.
-    if self._sharding and input_context and (
-        input_context.num_input_pipelines > 1 and
-        not self._enable_tf_data_service):
+    if self._sharding and input_context and (
+        input_context.num_input_pipelines > 1):
       dataset = dataset.shard(input_context.num_input_pipelines,
                               input_context.input_pipeline_id)
@@ -276,42 +299,53 @@ class InputReader:
   def _read_decode_and_parse_dataset(
       self,
-      matched_files: List[str],
+      matched_files: Union[Dict[str, List[str]], List[str]],
       dataset_fn,
       batch_size: int,
       input_context: Optional[tf.distribute.InputContext] = None,
       tfds_builder: bool = False) -> tf.data.Dataset:
     """Returns a tf.data.Dataset object after reading, decoding, and parsing."""
+
+    def _files_to_dataset(files: List[str]) -> tf.data.Dataset:
+      if len(files) > 1:
+        if input_context and (len(files) < input_context.num_input_pipelines):
+          logging.warn(
+              'The number of files %d is less than the number of input pipelines '
+              '%d. We will send all input files to every worker. '
+              'Please consider sharding your data into more files.', len(files),
+              input_context.num_input_pipelines)
+          return self._read_files_then_shard(files, dataset_fn, input_context)
+        else:
+          return self._shard_files_then_read(files, dataset_fn, input_context)
+      elif len(files) == 1:
+        return self._read_files_then_shard(files, dataset_fn, input_context)
+      else:
+        raise ValueError('It is unexpected that `tfds_builder` is None and '
+                         'there is also no `files`.')
+
+    def _shuffle_and_decode(ds):
+      # If cache is enabled, we will call `shuffle()` later after `cache()`.
+      if self._is_training and not self._cache:
+        ds = ds.shuffle(self._shuffle_buffer_size, seed=self._seed)
+      # Decode
+      ds = _maybe_map_fn(ds, self._decoder_fn)
+      return ds
+
     if tfds_builder:
       dataset = self._read_tfds(input_context)
-    elif len(matched_files) > 1:
-      if input_context and (len(matched_files) <
-                            input_context.num_input_pipelines):
-        logging.warn(
-            'The number of files %d is less than the number of input pipelines '
-            '%d. We will send all input files to every worker. '
-            'Please consider sharding your data into more files.',
-            len(matched_files), input_context.num_input_pipelines)
-        dataset = self._read_files_then_shard(matched_files, dataset_fn,
-                                              input_context)
-      else:
-        dataset = self._shard_files_then_read(matched_files, dataset_fn,
-                                              input_context)
-    elif len(matched_files) == 1:
-      dataset = self._read_files_then_shard(matched_files, dataset_fn,
-                                            input_context)
+      dataset = _shuffle_and_decode(dataset)
+    elif isinstance(matched_files, (list, tuple)):
+      dataset = _files_to_dataset(matched_files)
+      dataset = _shuffle_and_decode(dataset)
+    elif isinstance(matched_files, dict):
+      datasets = {}
+      for k, fs in matched_files.items():
+        datasets[k] = _files_to_dataset(fs)
+        datasets[k] = _shuffle_and_decode(datasets[k])
+      dataset = self._combine_fn(datasets)
     else:
-      raise ValueError('It is unexpected that `tfds_builder` is None and '
-                       'there is also no `matched_files`.')
+      raise ValueError('`matched_files` should be a list or dict.')
 
-    # If cache is enabled, we will call `shuffle()` later after `cache()`.
-    if self._is_training and not self._cache:
-      dataset = dataset.shuffle(self._shuffle_buffer_size, seed=self._seed)
-    dataset = _maybe_map_fn(dataset, self._decoder_fn)
     if self._sample_fn is not None:
       dataset = dataset.apply(self._sample_fn)
     dataset = _maybe_map_fn(dataset, self._parser_fn)
@@ -328,8 +362,7 @@ class InputReader:
     per_replica_batch_size = input_context.get_per_replica_batch_size(
         batch_size) if input_context else batch_size
-    dataset = dataset.batch(
-        per_replica_batch_size,
-        drop_remainder=self._drop_remainder)
+    dataset = dataset.batch(
+        per_replica_batch_size, drop_remainder=self._drop_remainder)
     return dataset
```
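A minimal sketch of a `combine_fn`, which per the new `InputReader` contract receives a dict of decoded `tf.data.Dataset` objects (keyed like the `input_path` dict) and must return a single dataset; the key names and 80/20 weighting are assumptions for illustration.

```python
import tensorflow as tf

def combine_fn(datasets):
  # Interleave examples from the two named sources with fixed sampling weights.
  return tf.data.experimental.sample_from_datasets(
      [datasets['coco'], datasets['voc']], weights=[0.8, 0.2])
```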
official/core/train_lib.py  (+18, -4)

```diff
@@ -15,13 +15,15 @@
 """TFM common training driver library."""
 # pytype: disable=attribute-error
 import os
-from typing import Any, Mapping, Tuple, Optional
+from typing import Any, Mapping, Optional, Tuple
 
 # Import libraries
 from absl import logging
 import orbit
 import tensorflow as tf
 
+from official.core import actions
 from official.core import base_task
 from official.core import base_trainer
 from official.core import config_definitions
@@ -38,7 +40,8 @@ def run_experiment(
     model_dir: str,
     run_post_eval: bool = False,
     save_summary: bool = True,
-    trainer: Optional[base_trainer.Trainer] = None
+    trainer: Optional[base_trainer.Trainer] = None,
+    controller_cls=orbit.Controller
 ) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
   """Runs train/eval configured by the experiment params.
@@ -54,6 +57,8 @@ def run_experiment(
     save_summary: Whether to save train and validation summary.
     trainer: the base_trainer.Trainer instance. It should be created within the
       strategy.scope().
+    controller_cls: The controller class to manage the train and eval process.
+      Must be a orbit.Controller subclass.
 
   Returns:
     A 2-tuple of (model, eval_logs).
@@ -73,6 +78,8 @@ def run_experiment(
             params, model_dir))
 
   if trainer.checkpoint:
+    if model_dir is None:
+      raise ValueError('model_dir must be specified, but got None')
     checkpoint_manager = tf.train.CheckpointManager(
         trainer.checkpoint,
         directory=model_dir,
@@ -85,7 +92,7 @@ def run_experiment(
   else:
     checkpoint_manager = None
 
-  controller = orbit.Controller(
+  controller = controller_cls(
       strategy=distribution_strategy,
       trainer=trainer if 'train' in mode else None,
       evaluator=trainer,
@@ -97,7 +104,9 @@ def run_experiment(
           params.trainer.validation_summary_subdir) if (save_summary) else None,
-      summary_interval=params.trainer.summary_interval if (save_summary) else None)
+      summary_interval=params.trainer.summary_interval if (save_summary) else None,
+      train_actions=actions.get_train_actions(params, trainer, model_dir),
+      eval_actions=actions.get_eval_actions(params, trainer, model_dir))
 
   logging.info('Starts to execute mode: %s', mode)
   with distribution_strategy.scope():
@@ -129,6 +138,11 @@ def run_experiment(
       logging.info('Number of trainable params in model: %f Millions.',
                    num_params / 10.**6)
 
+    flops = train_utils.try_count_flops(trainer.model)
+    if flops is not None:
+      logging.info('FLOPs (multi-adds) in model: %f Billions.',
+                   flops / 10.**9 / 2)
+
   if run_post_eval:
     with distribution_strategy.scope():
       return trainer.model, trainer.evaluate(
```
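A sketch of the new `controller_cls` hook: any `orbit.Controller` subclass can be injected into `run_experiment`. The logging override below is illustrative only and assumes `Controller.train(steps, checkpoint_at_completion=True)` as the method being wrapped.

```python
import orbit

class LoggingController(orbit.Controller):
  """Hypothetical controller that logs before delegating to Orbit."""

  def train(self, steps, checkpoint_at_completion=True):
    print(f'Training until global step {steps}')
    return super().train(steps, checkpoint_at_completion)

# Then: train_lib.run_experiment(..., controller_cls=LoggingController)
```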
official/core/train_utils.py  (+75, -9)

```diff
@@ -17,7 +17,7 @@ import copy
 import json
 import os
 import pprint
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Union
 
 from absl import logging
 import dataclasses
@@ -25,6 +25,9 @@ import gin
 import orbit
 import tensorflow as tf
 
+# pylint: disable=g-direct-tensorflow-import
+from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2_as_graph
+# pylint: enable=g-direct-tensorflow-import
 from official.core import base_task
 from official.core import base_trainer
 from official.core import config_definitions
@@ -139,14 +142,19 @@ class BestCheckpointExporter:
     return self._checkpoint_manager
 
-  def maybe_export_checkpoint(self, checkpoint, eval_logs, global_step):
+  def maybe_export_checkpoint(
+      self, checkpoint, eval_logs, global_step, write_logs=True) -> bool:
     """Compare eval_logs with past eval_logs and export checkpoint if better."""
     logging.info('[BestCheckpointExporter] received eval_logs: %s, at step: %d',
                  eval_logs, global_step)
     if self._best_ckpt_logs is None or self._new_metric_is_better(
         self._best_ckpt_logs, eval_logs):
       self._best_ckpt_logs = eval_logs
-      self._export_best_eval_metric(checkpoint, self._best_ckpt_logs,
-                                    global_step)
+      if write_logs:
+        self.export_best_eval_metric(self._best_ckpt_logs, global_step)
+      self._get_checkpoint_manager(checkpoint).save()
+      return True
+    return False
 
   def _maybe_load_best_eval_metric(self):
     if not tf.io.gfile.exists(self.best_ckpt_logs_path):
@@ -177,7 +185,7 @@ class BestCheckpointExporter:
       return True
     return False
 
-  def _export_best_eval_metric(self, checkpoint, eval_logs, global_step):
+  def export_best_eval_metric(self, eval_logs, global_step):
     """Export evaluation results of the best checkpoint into a json file."""
     eval_logs_ext = copy.copy(eval_logs)
     eval_logs_ext['best_ckpt_global_step'] = global_step
@@ -187,8 +195,6 @@ class BestCheckpointExporter:
     with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'w') as writer:
       writer.write(json.dumps(eval_logs_ext, indent=4) + '\n')
 
-    self._get_checkpoint_manager(checkpoint).save()
-
   @property
   def best_ckpt_logs(self):
     return self._best_ckpt_logs
@@ -241,6 +247,9 @@ class ParseConfigOptions:
 def parse_configuration(flags_obj, lock_return=True, print_return=True):
   """Parses ExperimentConfig from flags."""
 
+  if flags_obj.experiment is None:
+    raise ValueError('The flag --experiment must be specified.')
+
   # 1. Get the default config from the registered experiment.
   params = exp_factory.get_exp_config(flags_obj.experiment)
@@ -285,7 +294,7 @@ def parse_configuration(flags_obj, lock_return=True, print_return=True):
   if print_return:
     pp = pprint.PrettyPrinter()
-    logging.info('Final experiment parameters: %s',
+    logging.info('Final experiment parameters:\n%s',
                  pp.pformat(params.as_dict()))
 
   return params
@@ -294,6 +303,8 @@ def parse_configuration(flags_obj, lock_return=True, print_return=True):
 def serialize_config(params: config_definitions.ExperimentConfig,
                      model_dir: str):
   """Serializes and saves the experiment config."""
+  if model_dir is None:
+    raise ValueError('model_dir must be specified, but got None')
   params_save_path = os.path.join(model_dir, 'params.yaml')
   logging.info('Saving experiment configuration to %s', params_save_path)
   tf.io.gfile.makedirs(model_dir)
@@ -369,11 +380,15 @@ def remove_ckpts(model_dir):
     tf.io.gfile.remove(file_to_remove)
 
-def try_count_params(model: tf.keras.Model):
+def try_count_params(
+    model: Union[tf.Module, tf.keras.Model],
+    trainable_only: bool = False):
   """Count the number of parameters if model is possible.
 
   Args:
     model: Try to count the number of params in this model.
+    trainable_only: Whether to calculate trainable params only. This flag is
+      not used when the model has `count_params` attribute.
 
   Returns:
     The number of parameters or None.
@@ -387,4 +402,55 @@ def try_count_params(model: tf.keras.Model):
           'because the model was not feed any input, e.g., the max '
           'train step already reached before this run.')
       return None
+  else:
+    total_params = 0
+    variables = model.trainable_variables if trainable_only else model.variables
+    for var in variables:
+      shape = tf.shape(var)
+      total_params += tf.math.reduce_prod(shape).numpy()
+    return total_params
+
+
+def try_count_flops(model: Union[tf.Module, tf.keras.Model],
+                    inputs_kwargs: Optional[Dict[str, Any]] = None):
+  """Counts and returns model FLOPs.
+
+  Args:
+    model: A model instance.
+    inputs_kwargs: An optional dictionary of argument pairs specifying inputs'
+      shape specifications to getting corresponding concrete function.
+
+  Returns:
+    The model's FLOPs.
+  """
+  if hasattr(model, 'inputs'):
+    try:
+      # Get input shape and set batch size to 1.
+      if model.inputs:
+        inputs = [
+            tf.TensorSpec([1] + input.shape[1:], input.dtype)
+            for input in model.inputs
+        ]
+        concrete_func = tf.function(model).get_concrete_function(inputs)
+      # If model.inputs is invalid, try to use the input to get concrete
+      # function for model.call (subclass model).
+      else:
+        concrete_func = tf.function(model.call).get_concrete_function(
+            **inputs_kwargs)
+      frozen_func, _ = convert_variables_to_constants_v2_as_graph(concrete_func)
+
+      # Calculate FLOPs.
+      run_meta = tf.compat.v1.RunMetadata()
+      opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
+      opts['output'] = 'none'
+      flops = tf.compat.v1.profiler.profile(
+          graph=frozen_func.graph, run_meta=run_meta, options=opts)
+      return flops.total_float_ops
+    except Exception as e:  # pylint: disable=broad-except
+      logging.info(
+          'Failed to count model FLOPs with error %s, because the build() '
+          'methods in keras layers were not called. This is probably because '
+          'the model was not feed any input, e.g., the max train step already '
+          'reached before this run.', e)
+      return None
+  return None
```
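A usage sketch for the two counters above on a toy Keras model; since the model is built from an `Input` layer, it exposes both `count_params` and `model.inputs`, so neither fallback path is exercised.

```python
import tensorflow as tf
from official.core import train_utils

model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(32,)),
    tf.keras.layers.Dense(10),
])
print(train_utils.try_count_params(model))  # parameter count via count_params
print(train_utils.try_count_flops(model))   # FLOPs via the TF v1 profiler
```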
official/modeling/multitask/configs.py  (+2, -1)

```diff
@@ -23,6 +23,7 @@ from official.modeling import hyperparams
 @dataclasses.dataclass
 class TaskRoutine(hyperparams.Config):
+  # TODO(hongkuny): deprecate the task_name once we migrated client code.
   task_name: str = ""
   task_config: cfg.TaskConfig = None
   eval_steps: Optional[int] = None
@@ -76,4 +77,4 @@ class MultiEvalExperimentConfig(cfg.ExperimentConfig):
   Attributes:
     eval_tasks: individual evaluation tasks.
   """
-  eval_tasks: MultiTaskConfig = MultiTaskConfig()
+  eval_tasks: Tuple[TaskRoutine, ...] = ()
```
official/modeling/multitask/evaluator.py  (+22, -21)

`MultiTaskEvaluator` now takes a flat list of tasks (`eval_tasks`) plus an optional per-task `eval_steps` dict, instead of a `multitask.MultiTask` container.

```diff
@@ -16,14 +16,14 @@
 The evaluator implements the Orbit `AbstractEvaluator` interface.
 """
-from typing import Optional, Union
+from typing import Dict, List, Optional, Union
 
 import gin
 import orbit
 import tensorflow as tf
 
 from official.core import base_task
 from official.core import train_utils
 from official.modeling.multitask import base_model
 from official.modeling.multitask import multitask
 
 @gin.configurable
@@ -32,37 +32,39 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
   def __init__(self,
-               task: multitask.MultiTask,
+               eval_tasks: List[base_task.Task],
                model: Union[tf.keras.Model, base_model.MultiTaskBaseModel],
                global_step: Optional[tf.Variable] = None,
+               eval_steps: Optional[Dict[str, int]] = None,
                checkpoint_exporter: Optional[
                    train_utils.BestCheckpointExporter] = None):
     """Initialize common trainer for TensorFlow models.
 
     Args:
-      task: A multitask.MultiTask instance.
+      eval_tasks: A list of tasks to evaluate.
       model: tf.keras.Model instance.
       global_step: the global step variable.
+      eval_steps: a dictionary of steps to run eval keyed by task names.
       checkpoint_exporter: an object that has the `maybe_export_checkpoint`
         interface.
     """
     # Gets the current distribution strategy. If not inside any strategy scope,
     # it gets a single-replica no-op strategy.
     self._strategy = tf.distribute.get_strategy()
-    self._task = task
+    self._tasks = eval_tasks
     self._model = model
     self._global_step = global_step or orbit.utils.create_global_step()
     self._checkpoint_exporter = checkpoint_exporter
     self._checkpoint = tf.train.Checkpoint(
-        global_step=self.global_step,
-        model=self.model)
+        global_step=self.global_step, model=self.model)
 
     self._validation_losses = None
     self._validation_metrics = None
 
     # Builds per-task datasets.
     self.eval_datasets = {}
-    for name, task in self.task.tasks.items():
-      self.eval_datasets[name] = orbit.utils.make_distributed_dataset(
+    self.eval_steps = eval_steps or {}
+    for task in self.tasks:
+      self.eval_datasets[task.name] = orbit.utils.make_distributed_dataset(
           self.strategy, task.build_inputs, task.task_config.validation_data)
 
     # Builds per-task validation loops.
@@ -89,8 +91,7 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
       return orbit.utils.create_loop_fn(eval_step_fn)
 
     self.task_fns = {
-        name: get_function(name, task)
-        for name, task in self.task.tasks.items()
+        task.name: get_function(task.name, task) for task in self.tasks
     }
 
   @property
@@ -98,8 +99,8 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
     return self._strategy
 
   @property
-  def task(self):
-    return self._task
+  def tasks(self):
+    return self._tasks
 
   @property
   def model(self):
@@ -115,8 +116,8 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
     if self._validation_losses is None:
       # Builds the per-task metrics and losses.
       self._validation_losses = {}
-      for name in self.task.tasks:
-        self._validation_losses[name] = tf.keras.metrics.Mean(
+      for task in self.tasks:
+        self._validation_losses[task.name] = tf.keras.metrics.Mean(
             "validation_loss", dtype=tf.float32)
     return self._validation_losses
@@ -126,8 +127,8 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
     if self._validation_metrics is None:
       # Builds the per-task metrics and losses.
       self._validation_metrics = {}
-      for name, task in self.task.tasks.items():
-        self._validation_metrics[name] = task.build_metrics(training=False)
+      for task in self.tasks:
+        self._validation_metrics[task.name] = task.build_metrics(training=False)
     return self._validation_metrics
 
   @property
@@ -145,12 +146,12 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
     results = {}
     eval_iters = tf.nest.map_structure(iter, self.eval_datasets)
 
-    for name, task_eval_loop in self.task_fns.items():
+    for task in self.tasks:
       outputs = None
+      name = task.name
       eval_iter = eval_iters[name]
-      task = self.task.tasks[name]
-      task_eval_steps = self.task.task_eval_steps(name) or num_steps
-      outputs = task_eval_loop(
+      task_eval_steps = self.eval_steps.get(name, None) or num_steps
+      outputs = self.task_fns[name](
           eval_iter,
           task_eval_steps,
           state=outputs,
```
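A sketch of the new constructor contract: a flat list of tasks plus an optional per-task `eval_steps` dict keyed by task name. `my_task_a`, `my_task_b`, `model`, and the key names are placeholders, not part of this commit.

```python
import tensorflow as tf
from official.modeling.multitask import evaluator as evaluator_lib

test_evaluator = evaluator_lib.MultiTaskEvaluator(
    eval_tasks=[my_task_a, my_task_b],   # placeholder base_task.Task instances
    model=model,
    eval_steps={'task_a': 100, 'task_b': 50})
results = test_evaluator.evaluate(tf.convert_to_tensor(10, dtype=tf.int32))
```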
official/modeling/multitask/evaluator_test.py  (+3, -8)

```diff
@@ -22,7 +22,6 @@ from tensorflow.python.distribute import strategy_combinations
 from official.core import base_task
 from official.core import config_definitions as cfg
 from official.modeling.multitask import evaluator
-from official.modeling.multitask import multitask
 
 def all_strategy_combinations():
@@ -89,9 +88,7 @@ class MockTask(base_task.Task):
           np.concatenate([np.expand_dims(v.numpy(), axis=0) for v in value]))
     return state
 
-  def reduce_aggregated_logs(self,
-                             aggregated_logs,
-                             global_step=None):
+  def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
     for k, v in aggregated_logs.items():
       aggregated_logs[k] = np.sum(np.stack(v, axis=0))
     return aggregated_logs
@@ -106,10 +103,9 @@ class EvaluatorTest(tf.test.TestCase, parameterized.TestCase):
         MockTask(params=cfg.TaskConfig(), name="bar"),
         MockTask(params=cfg.TaskConfig(), name="foo")
     ]
-    test_multitask = multitask.MultiTask(tasks=tasks)
     model = MockModel()
     test_evaluator = evaluator.MultiTaskEvaluator(
-        task=test_multitask, model=model)
+        eval_tasks=tasks, model=model)
     results = test_evaluator.evaluate(tf.convert_to_tensor(1, dtype=tf.int32))
     self.assertContainsSubset(["validation_loss", "acc"], results["bar"].keys())
     self.assertContainsSubset(["validation_loss", "acc"], results["foo"].keys())
@@ -123,10 +119,9 @@ class EvaluatorTest(tf.test.TestCase, parameterized.TestCase):
         MockTask(params=cfg.TaskConfig(), name="bar"),
         MockTask(params=cfg.TaskConfig(), name="foo")
     ]
-    test_multitask = multitask.MultiTask(tasks=tasks)
     model = MockModel()
     test_evaluator = evaluator.MultiTaskEvaluator(
-        task=test_multitask, model=model)
+        eval_tasks=tasks, model=model)
     results = test_evaluator.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
     self.assertEqual(results["bar"]["counter"],
                      5. * distribution.num_replicas_in_sync)
```
official/modeling/multitask/multitask.py  (+3, -9)

```diff
@@ -59,10 +59,7 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
     else:
       raise ValueError("The tasks argument has an invalid type: %s" %
                        type(tasks))
-    self._task_eval_steps = task_eval_steps or {}
-    self._task_eval_steps = dict([
-        (name, self._task_eval_steps.get(name, None)) for name in self.tasks
-    ])
+    self.task_eval_steps = task_eval_steps or {}
     self._task_weights = task_weights or {}
     self._task_weights = dict([
         (name, self._task_weights.get(name, 1.0)) for name in self.tasks
@@ -74,9 +71,9 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
     task_eval_steps = {}
     task_weights = {}
     for task_routine in config.task_routines:
-      task_name = task_routine.task_name
+      task_name = task_routine.task_name or task_routine.task_config.name
       tasks[task_name] = task_factory.get_task(
-          task_routine.task_config, logging_dir=logging_dir)
+          task_routine.task_config, logging_dir=logging_dir, name=task_name)
       task_eval_steps[task_name] = task_routine.eval_steps
       task_weights[task_name] = task_routine.task_weight
     return cls(
@@ -86,9 +83,6 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
   def tasks(self):
     return self._tasks
 
-  def task_eval_steps(self, task_name):
-    return self._task_eval_steps[task_name]
-
   def task_weight(self, task_name):
     return self._task_weights[task_name]
```
official/modeling/multitask/train_lib.py  (+13, -7)

```diff
@@ -15,7 +15,7 @@
 """Multitask training driver library."""
 # pytype: disable=attribute-error
 import os
-from typing import Optional
+from typing import List, Optional
 from absl import logging
 import orbit
 import tensorflow as tf
@@ -69,9 +69,11 @@ def run_experiment(*, distribution_strategy: tf.distribute.Strategy,
   trainer = TRAINERS[params.trainer.trainer_type](
       **kwargs) if is_training else None
   if is_eval:
+    eval_steps = task.task_eval_steps
     evaluator = evaluator_lib.MultiTaskEvaluator(
-        task=task,
+        eval_tasks=task.tasks.values(),
         model=model,
+        eval_steps=eval_steps,
         global_step=trainer.global_step if is_training else None,
         checkpoint_exporter=train_utils.maybe_create_best_ckpt_exporter(
             params, model_dir))
@@ -137,7 +139,7 @@ def run_experiment_with_multitask_eval(
     *,
     distribution_strategy: tf.distribute.Strategy,
     train_task: base_task.Task,
-    eval_tasks: multitask.MultiTask,
+    eval_tasks: List[base_task.Task],
     mode: str,
     params: configs.MultiEvalExperimentConfig,
     model_dir: str,
@@ -149,7 +151,7 @@ def run_experiment_with_multitask_eval(
   Args:
     distribution_strategy: A distribution distribution_strategy.
     train_task: A base_task.Task instance.
-    eval_tasks: A multitask.MultiTask with evaluation tasks.
+    eval_tasks: A list of evaluation tasks.
     mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
       or 'continuous_eval'.
     params: MultiEvalExperimentConfig instance.
@@ -173,8 +175,8 @@ def run_experiment_with_multitask_eval(
         config=params,
         task=train_task,
         model=train_task.build_model(),
-        optimizer=train_task.create_optimizer(params.trainer.optimizer_config,
-                                              params.runtime),
+        optimizer=train_task.create_optimizer(
+            params.trainer.optimizer_config, params.runtime),
         train=True,
         evaluate=False)
   else:
@@ -182,10 +184,14 @@ def run_experiment_with_multitask_eval(
     model = trainer.model if trainer else train_task.build_model()
   if is_eval:
+    eval_steps = dict([(task_routine.task_config.name, task_routine.eval_steps)
+                       for task_routine in params.eval_tasks])
     evaluator = evaluator_lib.MultiTaskEvaluator(
-        task=eval_tasks,
+        eval_tasks=eval_tasks,
         model=model,
         global_step=trainer.global_step if is_training else None,
+        eval_steps=eval_steps,
         checkpoint_exporter=train_utils.maybe_create_best_ckpt_exporter(
             params, model_dir))
   else:
```
official/modeling/multitask/train_lib_test.py  (+11, -10)

```diff
@@ -65,8 +65,7 @@ class TrainLibTest(tf.test.TestCase, parameterized.TestCase):
         task=configs.MultiTaskConfig(
             task_routines=(
-                configs.TaskRoutine(
-                    task_name='foo', task_config=test_utils.FooConfig()),
+                configs.TaskRoutine(task_name='foo',
+                                    task_config=test_utils.FooConfig()),
                 configs.TaskRoutine(
                     task_name='bar', task_config=test_utils.BarConfig()))))
     experiment_config = params_dict.override_params_dict(
@@ -95,18 +94,20 @@ class TrainLibTest(tf.test.TestCase, parameterized.TestCase):
     model_dir = self.get_temp_dir()
     experiment_config = configs.MultiEvalExperimentConfig(
         task=test_utils.FooConfig(),
-        eval_tasks=configs.MultiTaskConfig(
-            task_routines=(
-                configs.TaskRoutine(
-                    task_name='foo', task_config=test_utils.FooConfig()),
-                configs.TaskRoutine(
-                    task_name='bar', task_config=test_utils.BarConfig()))))
+        eval_tasks=(configs.TaskRoutine(
+            task_name='foo', task_config=test_utils.FooConfig(), eval_steps=2),
+                    configs.TaskRoutine(
+                        task_name='bar',
+                        task_config=test_utils.BarConfig(),
+                        eval_steps=3)))
     experiment_config = params_dict.override_params_dict(
         experiment_config, self._test_config, is_strict=False)
     with distribution_strategy.scope():
       train_task = task_factory.get_task(experiment_config.task)
-      eval_tasks = multitask.MultiTask.from_config(experiment_config.eval_tasks)
+      eval_tasks = [
+          task_factory.get_task(config.task_config, name=config.task_name)
+          for config in experiment_config.eval_tasks
+      ]
       train_lib.run_experiment_with_multitask_eval(
           distribution_strategy=distribution_strategy,
           train_task=train_task,
```