Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
f16a7b5b
Unverified
Commit
f16a7b5b
authored
May 04, 2021
by
vedanshu
Committed by
GitHub
May 04, 2021
Browse files
Merge pull request
#1
from tensorflow/master
new pull
parents
8e9296ff
8f58f396
Changes
298
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1591 additions
and
304 deletions
+1591
-304
official/modeling/hyperparams/__init__.py
official/modeling/hyperparams/__init__.py
+2
-3
official/modeling/hyperparams/base_config.py
official/modeling/hyperparams/base_config.py
+41
-19
official/modeling/hyperparams/base_config_test.py
official/modeling/hyperparams/base_config_test.py
+68
-7
official/modeling/hyperparams/config_definitions.py
official/modeling/hyperparams/config_definitions.py
+10
-178
official/modeling/hyperparams/oneof.py
official/modeling/hyperparams/oneof.py
+6
-11
official/modeling/hyperparams/oneof_test.py
official/modeling/hyperparams/oneof_test.py
+13
-9
official/modeling/hyperparams/params_dict.py
official/modeling/hyperparams/params_dict.py
+60
-35
official/modeling/hyperparams/params_dict_test.py
official/modeling/hyperparams/params_dict_test.py
+125
-42
official/modeling/multitask/__init__.py
official/modeling/multitask/__init__.py
+14
-0
official/modeling/multitask/base_model.py
official/modeling/multitask/base_model.py
+60
-0
official/modeling/multitask/base_trainer.py
official/modeling/multitask/base_trainer.py
+176
-0
official/modeling/multitask/base_trainer_test.py
official/modeling/multitask/base_trainer_test.py
+90
-0
official/modeling/multitask/configs.py
official/modeling/multitask/configs.py
+79
-0
official/modeling/multitask/evaluator.py
official/modeling/multitask/evaluator.py
+172
-0
official/modeling/multitask/evaluator_test.py
official/modeling/multitask/evaluator_test.py
+138
-0
official/modeling/multitask/interleaving_trainer.py
official/modeling/multitask/interleaving_trainer.py
+92
-0
official/modeling/multitask/interleaving_trainer_test.py
official/modeling/multitask/interleaving_trainer_test.py
+101
-0
official/modeling/multitask/multitask.py
official/modeling/multitask/multitask.py
+148
-0
official/modeling/multitask/task_sampler.py
official/modeling/multitask/task_sampler.py
+121
-0
official/modeling/multitask/task_sampler_test.py
official/modeling/multitask/task_sampler_test.py
+75
-0
No files found.
Too many changes to show.
To preserve performance only
298 of 298+
files are displayed.
Plain diff
Email patch
official/modeling/hyperparams/__init__.py
View file @
f16a7b5b
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Hyperparams package definition."""
# pylint: disable=g-multiple-import
from
official.modeling.hyperparams.base_config
import
*
...
...
official/modeling/hyperparams/base_config.py
View file @
f16a7b5b
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -12,17 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base configurations to standardize experiments."""
from
__future__
import
absolute_import
from
__future__
import
division
# from __future__ import google_type_annotations
from
__future__
import
print_function
"""Base configurations to standardize experiments."""
import
copy
import
functools
from
typing
import
Any
,
List
,
Mapping
,
Optional
,
Type
from
absl
import
logging
import
dataclasses
import
tensorflow
as
tf
...
...
@@ -35,11 +30,15 @@ from official.modeling.hyperparams import params_dict
class
Config
(
params_dict
.
ParamsDict
):
"""The base configuration class that supports YAML/JSON based overrides.
* It recursively enforces a whitelist of basic types and container types, so
Because of YAML/JSON serialization limitations, some semantics of dataclass
are not supported:
* It recursively enforces a allowlist of basic types and container types, so
it avoids surprises with copy and reuse caused by unanticipated types.
*
I
t converts
d
ict to Config even within sequences,
*
Warning: i
t converts
D
ict to
`
Config
`
even within sequences,
e.g. for config = Config({'key': [([{'a': 42}],)]),
type(config.key[0][0][0]) is Config rather than dict.
If you define/annotate some field as Dict, the field will convert to a
`Config` instance and lose the dictionary type.
"""
# It's safe to add bytes and other immutable types here.
...
...
@@ -142,10 +141,11 @@ class Config(params_dict.ParamsDict):
return
subconfig_type
def
__post_init__
(
self
,
default_params
,
restrictions
,
*
args
,
**
kwargs
):
super
().
__init__
(
default_params
=
default_params
,
restrictions
=
restrictions
,
*
args
,
**
kwargs
)
super
().
__init__
(
default_params
=
default_params
,
restrictions
=
restrictions
,
*
args
,
**
kwargs
)
def
_set
(
self
,
k
,
v
):
"""Overrides same method in ParamsDict.
...
...
@@ -160,13 +160,32 @@ class Config(params_dict.ParamsDict):
RuntimeError
"""
subconfig_type
=
self
.
_get_subconfig_type
(
k
)
if
isinstance
(
v
,
dict
):
def
is_null
(
k
):
if
k
not
in
self
.
__dict__
or
not
self
.
__dict__
[
k
]:
return
True
return
False
if
isinstance
(
v
,
dict
):
if
is_null
(
k
):
# If the key not exist or the value is None, a new Config-family object
# sould be created for the key.
self
.
__dict__
[
k
]
=
subconfig_type
(
v
)
else
:
self
.
__dict__
[
k
].
override
(
v
)
elif
not
is_null
(
k
)
and
isinstance
(
v
,
self
.
SEQUENCE_TYPES
)
and
all
(
[
not
isinstance
(
e
,
self
.
IMMUTABLE_TYPES
)
for
e
in
v
]):
if
len
(
self
.
__dict__
[
k
])
==
len
(
v
):
for
i
in
range
(
len
(
v
)):
self
.
__dict__
[
k
][
i
].
override
(
v
[
i
])
elif
not
all
([
isinstance
(
e
,
self
.
IMMUTABLE_TYPES
)
for
e
in
v
]):
logging
.
warning
(
"The list/tuple don't match the value dictionaries provided. Thus, "
'the list/tuple is determined by the type annotation and '
'values provided. This is error-prone.'
)
self
.
__dict__
[
k
]
=
self
.
_import_config
(
v
,
subconfig_type
)
else
:
self
.
__dict__
[
k
]
=
self
.
_import_config
(
v
,
subconfig_type
)
else
:
self
.
__dict__
[
k
]
=
self
.
_import_config
(
v
,
subconfig_type
)
...
...
@@ -220,16 +239,19 @@ class Config(params_dict.ParamsDict):
}
def
replace
(
self
,
**
kwargs
):
"""Like `override`, but returns a copy with the current config unchanged."""
params
=
self
.
__class__
(
self
)
params
.
override
(
kwargs
,
is_strict
=
True
)
"""Overrides/returns a unlocked copy with the current config unchanged."""
# pylint: disable=protected-access
params
=
copy
.
deepcopy
(
self
)
params
.
_locked
=
False
params
.
_override
(
kwargs
,
is_strict
=
True
)
# pylint: enable=protected-access
return
params
@
classmethod
def
from_yaml
(
cls
,
file_path
:
str
):
# Note: This only works if the Config has all default values.
with
tf
.
io
.
gfile
.
GFile
(
file_path
,
'r'
)
as
f
:
loaded
=
yaml
.
load
(
f
)
loaded
=
yaml
.
load
(
f
,
Loader
=
yaml
.
FullLoader
)
config
=
cls
()
config
.
override
(
loaded
)
return
config
...
...
official/modeling/hyperparams/base_config_test.py
View file @
f16a7b5b
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -12,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import
pprint
from
typing
import
List
,
Tuple
...
...
@@ -45,6 +43,17 @@ class DumpConfig3(DumpConfig2):
g
:
Tuple
[
DumpConfig1
,
...]
=
(
DumpConfig1
(),)
@
dataclasses
.
dataclass
class
DumpConfig4
(
DumpConfig2
):
x
:
int
=
3
@
dataclasses
.
dataclass
class
DummyConfig5
(
base_config
.
Config
):
y
:
Tuple
[
DumpConfig2
,
...]
=
(
DumpConfig2
(),
DumpConfig4
())
z
:
Tuple
[
str
]
=
(
'a'
,)
class
BaseConfigTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
def
assertHasSameTypes
(
self
,
c
,
d
,
msg
=
''
):
...
...
@@ -106,6 +115,22 @@ class BaseConfigTest(parameterized.TestCase, tf.test.TestCase):
self
.
assertEqual
(
config
.
g
[
0
].
a
,
4
)
self
.
assertEqual
(
config
.
g
[
0
].
b
,
'new text 3'
)
def
test_replace
(
self
):
config
=
DumpConfig2
()
new_config
=
config
.
replace
(
e
=
{
'a'
:
2
})
self
.
assertEqual
(
new_config
.
e
.
a
,
2
)
self
.
assertIsInstance
(
new_config
.
e
,
DumpConfig1
)
config
=
DumpConfig2
(
e
=
DumpConfig2
())
new_config
=
config
.
replace
(
e
=
{
'c'
:
4
})
self
.
assertEqual
(
new_config
.
e
.
c
,
4
)
self
.
assertIsInstance
(
new_config
.
e
,
DumpConfig2
)
config
=
DumpConfig3
()
new_config
=
config
.
replace
(
g
=
[{
'a'
:
4
,
'b'
:
'new text 3'
}])
self
.
assertIsInstance
(
new_config
.
g
[
0
],
DumpConfig1
)
self
.
assertEqual
(
new_config
.
g
[
0
].
a
,
4
)
@
parameterized
.
parameters
(
(
'_locked'
,
"The key '_locked' is internally reserved."
),
(
'_restrictions'
,
"The key '_restrictions' is internally reserved."
),
...
...
@@ -147,10 +172,8 @@ class BaseConfigTest(parameterized.TestCase, tf.test.TestCase):
params
.
override
({
'c'
:
{
'c3'
:
30
}},
is_strict
=
True
)
config
=
base_config
.
Config
({
'key'
:
[{
'a'
:
42
}]})
config
.
override
({
'key'
:
[{
'b'
:
43
}]})
self
.
assertEqual
(
config
.
key
[
0
].
b
,
43
)
with
self
.
assertRaisesRegex
(
AttributeError
,
'The key `a` does not exist'
):
_
=
config
.
key
[
0
].
a
with
self
.
assertRaisesRegex
(
KeyError
,
"The key 'b' does not exist"
):
config
.
override
({
'key'
:
[{
'b'
:
43
}]})
@
parameterized
.
parameters
(
(
lambda
x
:
x
,
'Unknown type'
),
...
...
@@ -294,6 +317,44 @@ class BaseConfigTest(parameterized.TestCase, tf.test.TestCase):
]),
"['s', 1, 1.0, True, None, {}, [], (), {8: 9, (2,): (3, [4], {6: 7})}]"
)
def
test_with_restrictions
(
self
):
restrictions
=
[
'e.a<c'
]
config
=
DumpConfig2
(
restrictions
=
restrictions
)
config
.
validate
()
def
test_nested_tuple
(
self
):
config
=
DummyConfig5
()
config
.
override
({
'y'
:
[{
'c'
:
4
,
'd'
:
'new text 3'
,
'e'
:
{
'a'
:
2
}
},
{
'c'
:
0
,
'd'
:
'new text 3'
,
'e'
:
{
'a'
:
2
}
}],
'z'
:
[
'a'
,
'b'
,
'c'
],
})
self
.
assertEqual
(
config
.
y
[
0
].
c
,
4
)
self
.
assertEqual
(
config
.
y
[
1
].
c
,
0
)
self
.
assertIsInstance
(
config
.
y
[
0
],
DumpConfig2
)
self
.
assertIsInstance
(
config
.
y
[
1
],
DumpConfig4
)
self
.
assertSameElements
(
config
.
z
,
[
'a'
,
'b'
,
'c'
])
def
test_override_by_empty_sequence
(
self
):
config
=
DummyConfig5
()
config
.
override
({
'y'
:
[],
'z'
:
(),
},
is_strict
=
True
)
self
.
assertEmpty
(
config
.
y
)
self
.
assertEmpty
(
config
.
z
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/modeling/hyperparams/config_definitions.py
View file @
f16a7b5b
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -12,124 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common configuration settings."""
from
typing
import
Optional
,
Union
"""Common configuration settings."""
# pylint:disable=wildcard-import
import
dataclasses
from
official.core.config_definitions
import
*
from
official.modeling.hyperparams
import
base_config
from
official.modeling.optimization.configs
import
optimization_config
from
official.utils
import
registry
OptimizationConfig
=
optimization_config
.
OptimizationConfig
@
dataclasses
.
dataclass
class
DataConfig
(
base_config
.
Config
):
"""The base configuration for building datasets.
Attributes:
input_path: The path to the input. It can be either (1) a file pattern, or
(2) multiple file patterns separated by comma. It should not be specified
when the following `tfds_name` is specified.
tfds_name: The name of the tensorflow dataset (TFDS). It should not be
specified when the above `input_path` is specified.
tfds_split: A str indicating which split of the data to load from TFDS. It
is required when above `tfds_name` is specified.
global_batch_size: The global batch size across all replicas.
is_training: Whether this data is used for training or not.
drop_remainder: Whether the last batch should be dropped in the case it has
fewer than `global_batch_size` elements.
shuffle_buffer_size: The buffer size used for shuffling training data.
cache: Whether to cache dataset examples. Can be used to avoid re-reading
from disk on the second epoch. Requires significant memory overhead.
cycle_length: The number of files that will be processed concurrently when
interleaving files.
block_length: The number of consecutive elements to produce from each input
element before cycling to another input element when interleaving files.
sharding: Whether sharding is used in the input pipeline.
examples_consume: An `integer` specifying the number of examples it will
produce. If positive, it only takes this number of examples and raises
tf.error.OutOfRangeError after that. Default is -1, meaning it will
exhaust all the examples in the dataset.
tfds_data_dir: A str specifying the directory to read/write TFDS data.
tfds_download: A bool to indicate whether to download data using TFDS.
tfds_as_supervised: A bool. When loading dataset from TFDS, if True,
the returned tf.data.Dataset will have a 2-tuple structure (input, label)
according to builder.info.supervised_keys; if False, the default,
the returned tf.data.Dataset will have a dictionary with all the features.
tfds_skip_decoding_feature: A str to indicate which features are skipped
for decoding when loading dataset from TFDS. Use comma to separate
multiple features. The main use case is to skip the image/video decoding
for better performance.
"""
input_path
:
str
=
""
tfds_name
:
str
=
""
tfds_split
:
str
=
""
global_batch_size
:
int
=
0
is_training
:
bool
=
None
drop_remainder
:
bool
=
True
shuffle_buffer_size
:
int
=
100
cache
:
bool
=
False
cycle_length
:
int
=
8
block_length
:
int
=
1
sharding
:
bool
=
True
examples_consume
:
int
=
-
1
tfds_data_dir
:
str
=
""
tfds_download
:
bool
=
False
tfds_as_supervised
:
bool
=
False
tfds_skip_decoding_feature
:
str
=
""
@
dataclasses
.
dataclass
class
RuntimeConfig
(
base_config
.
Config
):
"""High-level configurations for Runtime.
These include parameters that are not directly related to the experiment,
e.g. directories, accelerator type, etc.
Attributes:
distribution_strategy: e.g. 'mirrored', 'tpu', etc.
enable_xla: Whether or not to enable XLA.
per_gpu_thread_count: thread count per GPU.
gpu_thread_mode: Whether and how the GPU device uses its own threadpool.
dataset_num_private_threads: Number of threads for a private threadpool
created for all datasets computation.
tpu: The address of the TPU to use, if any.
num_gpus: The number of GPUs to use, if any.
worker_hosts: comma-separated list of worker ip:port pairs for running
multi-worker models with DistributionStrategy.
task_index: If multi-worker training, the task index of this worker.
all_reduce_alg: Defines the algorithm for performing all-reduce.
num_packs: Sets `num_packs` in the cross device ops used in
MirroredStrategy. For details, see tf.distribute.NcclAllReduce.
mixed_precision_dtype: dtype of mixed precision policy. It can be 'float32',
'float16', or 'bfloat16'.
loss_scale: The type of loss scale, or 'float' value. This is used when
setting the mixed precision policy.
run_eagerly: Whether or not to run the experiment eagerly.
batchnorm_spatial_persistent: Whether or not to enable the spatial
persistent mode for CuDNN batch norm kernel for improved GPU performance.
"""
distribution_strategy
:
str
=
"mirrored"
enable_xla
:
bool
=
False
gpu_thread_mode
:
Optional
[
str
]
=
None
dataset_num_private_threads
:
Optional
[
int
]
=
None
per_gpu_thread_count
:
int
=
0
tpu
:
Optional
[
str
]
=
None
num_gpus
:
int
=
0
worker_hosts
:
Optional
[
str
]
=
None
task_index
:
int
=
-
1
all_reduce_alg
:
Optional
[
str
]
=
None
num_packs
:
int
=
1
mixed_precision_dtype
:
Optional
[
str
]
=
None
loss_scale
:
Optional
[
Union
[
str
,
float
]]
=
None
run_eagerly
:
bool
=
False
batchnorm_spatial_persistent
:
bool
=
False
# TODO(hongkuny): These configs are used in models that are going to deprecate.
# Once those models are removed, we should delete this file to avoid confusion.
# Users should not use this file anymore.
@
dataclasses
.
dataclass
class
TensorboardConfig
(
base_config
.
Config
):
"""Configuration for Tensorboard.
...
...
@@ -151,75 +44,14 @@ class CallbacksConfig(base_config.Config):
Attributes:
enable_checkpoint_and_export: Whether or not to enable checkpoints as a
Callback. Defaults to True.
enable_backup_and_restore: Whether or not to add BackupAndRestore
callback. Defaults to True.
enable_tensorboard: Whether or not to enable Tensorboard as a Callback.
Defaults to True.
enable_time_history: Whether or not to enable TimeHistory Callbacks.
Defaults to True.
"""
enable_checkpoint_and_export
:
bool
=
True
enable_backup_and_restore
:
bool
=
False
enable_tensorboard
:
bool
=
True
enable_time_history
:
bool
=
True
@
dataclasses
.
dataclass
class
TrainerConfig
(
base_config
.
Config
):
"""Configuration for trainer.
Attributes:
optimizer_config: optimizer config, it includes optimizer, learning rate,
and warmup schedule configs.
train_tf_while_loop: whether or not to use tf while loop.
train_tf_function: whether or not to use tf_function for training loop.
eval_tf_function: whether or not to use tf_function for eval.
steps_per_loop: number of steps per loop.
summary_interval: number of steps between each summary.
checkpoint_interval: number of steps between checkpoints.
max_to_keep: max checkpoints to keep.
continuous_eval_timeout: maximum number of seconds to wait between
checkpoints, if set to None, continuous eval will wait indefinitely.
train_steps: number of train steps.
validation_steps: number of eval steps. If `None`, the entire eval dataset
is used.
validation_interval: number of training steps to run between evaluations.
"""
optimizer_config
:
OptimizationConfig
=
OptimizationConfig
()
train_tf_while_loop
:
bool
=
True
train_tf_function
:
bool
=
True
eval_tf_function
:
bool
=
True
steps_per_loop
:
int
=
1000
summary_interval
:
int
=
1000
checkpoint_interval
:
int
=
1000
max_to_keep
:
int
=
5
continuous_eval_timeout
:
Optional
[
int
]
=
None
train_steps
:
int
=
0
validation_steps
:
Optional
[
int
]
=
None
validation_interval
:
int
=
1000
@
dataclasses
.
dataclass
class
TaskConfig
(
base_config
.
Config
):
model
:
base_config
.
Config
=
None
train_data
:
DataConfig
=
DataConfig
()
validation_data
:
DataConfig
=
DataConfig
()
@
dataclasses
.
dataclass
class
ExperimentConfig
(
base_config
.
Config
):
"""Top-level configuration."""
task
:
TaskConfig
=
TaskConfig
()
trainer
:
TrainerConfig
=
TrainerConfig
()
runtime
:
RuntimeConfig
=
RuntimeConfig
()
_REGISTERED_CONFIGS
=
{}
def
register_config_factory
(
name
):
"""Register ExperimentConfig factory method."""
return
registry
.
register
(
_REGISTERED_CONFIGS
,
name
)
def
get_exp_config_creater
(
exp_name
:
str
):
"""Looks up ExperimentConfig factory methods."""
exp_creater
=
registry
.
lookup
(
_REGISTERED_CONFIGS
,
exp_name
)
return
exp_creater
official/modeling/hyperparams/oneof.py
View file @
f16a7b5b
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Config class that supports oneof functionality."""
from
typing
import
Optional
...
...
@@ -38,15 +37,12 @@ class OneOfConfig(base_config.Config):
if
self
.
type
is
None
:
return
{
'type'
:
None
}
elif
self
.
__dict__
[
'type'
]
not
in
self
.
__dict__
:
raise
ValueError
(
'type: {!r} is not a valid key!'
.
format
(
self
.
__dict__
[
'type'
]))
raise
ValueError
(
'type: {!r} is not a valid key!'
.
format
(
self
.
__dict__
[
'type'
]))
else
:
chosen_type
=
self
.
type
chosen_value
=
self
.
__dict__
[
chosen_type
]
return
{
'type'
:
self
.
type
,
chosen_type
:
self
.
_export_config
(
chosen_value
)
}
return
{
'type'
:
self
.
type
,
chosen_type
:
self
.
_export_config
(
chosen_value
)}
def
get
(
self
):
"""Returns selected config based on the value of type.
...
...
@@ -57,6 +53,5 @@ class OneOfConfig(base_config.Config):
if
chosen_type
is
None
:
return
None
if
chosen_type
not
in
self
.
__dict__
:
raise
ValueError
(
'type: {!r} is not a valid key!'
.
format
(
self
.
type
))
raise
ValueError
(
'type: {!r} is not a valid key!'
.
format
(
self
.
type
))
return
self
.
__dict__
[
chosen_type
]
official/modeling/hyperparams/oneof_test.py
View file @
f16a7b5b
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -12,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import
dataclasses
import
tensorflow
as
tf
...
...
@@ -48,12 +46,18 @@ class Network(base_config.Config):
class
OneOfTest
(
tf
.
test
.
TestCase
):
def
test_to_dict
(
self
):
network_params
=
{
'backbone'
:
{
'type'
:
'resnet'
,
'resnet'
:
{
'model_depth'
:
50
}
},
'output_layer'
:
{
'type'
:
'single'
,
'single'
:
1000
}
}
network_params
=
{
'backbone'
:
{
'type'
:
'resnet'
,
'resnet'
:
{
'model_depth'
:
50
}
},
'output_layer'
:
{
'type'
:
'single'
,
'single'
:
1000
}
}
network_config
=
Network
(
network_params
)
self
.
assertEqual
(
network_config
.
as_dict
(),
network_params
)
...
...
official/modeling/hyperparams/params_dict.py
View file @
f16a7b5b
# Copyright 201
9
The TensorFlow Authors. All Rights Reserved.
# Copyright 20
2
1 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -11,12 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A parameter dictionary class which supports the nest structure."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
"""A parameter dictionary class which supports the nest structure."""
import
collections
import
copy
...
...
@@ -30,7 +26,8 @@ import yaml
# key-value pair string. It splits each k-v pair on the = sign, and
# matches on values that are within single quotes, double quotes, single
# values (e.g. floats, ints, etc.), and a lists within brackets.
_PARAM_RE
=
re
.
compile
(
r
"""
_PARAM_RE
=
re
.
compile
(
r
"""
(?P<name>[a-zA-Z][\w\.]*) # variable name: "var" or "x"
\s*=\s*
((?P<val>\'(.*?)\' # single quote
...
...
@@ -44,6 +41,26 @@ _PARAM_RE = re.compile(r"""
_CONST_VALUE_RE
=
re
.
compile
(
r
'(\d.*|-\d.*|None)'
)
# Yaml loader with an implicit resolver to parse float decimal and exponential
# format. The regular experission parse the following cases:
# 1- Decimal number with an optional exponential term.
# 2- Integer number with an exponential term.
# 3- Decimal number with an optional exponential term.
# 4- Decimal number.
LOADER
=
yaml
.
SafeLoader
LOADER
.
add_implicit_resolver
(
'tag:yaml.org,2002:float'
,
re
.
compile
(
r
'''
^(?:[-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
|
[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
|
\\.[0-9_]+(?:[eE][-+][0-9]+)?
|
[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*)$'''
,
re
.
X
),
list
(
'-+0123456789.'
))
class
ParamsDict
(
object
):
"""A hyperparameter container class."""
...
...
@@ -72,7 +89,6 @@ class ParamsDict(object):
if
default_params
is
None
:
default_params
=
{}
self
.
override
(
default_params
,
is_strict
=
False
)
self
.
validate
()
def
_set
(
self
,
k
,
v
):
if
isinstance
(
v
,
dict
):
...
...
@@ -138,8 +154,8 @@ class ParamsDict(object):
ValueError: if the ParamsDict instance has been locked.
"""
if
k
in
ParamsDict
.
RESERVED_ATTR
:
raise
AttributeError
(
'The key `{}` is reserved. No change is allowes. '
.
format
(
k
))
raise
AttributeError
(
'The key `{}` is reserved. No change is allowes. '
.
format
(
k
))
if
k
not
in
self
.
__dict__
.
keys
():
raise
AttributeError
(
'The key `{}` does not exist. '
.
format
(
k
))
if
self
.
_locked
:
...
...
@@ -150,13 +166,13 @@ class ParamsDict(object):
"""Override the ParamsDict with a set of given params.
Args:
override_params: a dict or a ParamsDict specifying the parameters to
be
overridden.
override_params: a dict or a ParamsDict specifying the parameters to
be
overridden.
is_strict: a boolean specifying whether override is strict or not. If
True, keys in `override_params` must be present in the ParamsDict.
If
False, keys in `override_params` can be different from what is
currently
defined in the ParamsDict. In this case, the ParamsDict will
be extended
to include the new keys.
True, keys in `override_params` must be present in the ParamsDict.
If
False, keys in `override_params` can be different from what is
currently
defined in the ParamsDict. In this case, the ParamsDict will
be extended
to include the new keys.
"""
if
self
.
_locked
:
raise
ValueError
(
'The ParamsDict has been locked. No change is allowed.'
)
...
...
@@ -230,7 +246,7 @@ class ParamsDict(object):
['a.a1 == b.ccc.a1', 'a.a2 <= b.bb.bb2']
What it enforces are:
- a.a1 = 1 == b.ccc.a1 =
2
- a.a1 = 1 == b.ccc.a1 =
1
- a.a2 = 2 <= b.bb.bb2 = 20
Raises:
...
...
@@ -240,6 +256,7 @@ class ParamsDict(object):
(2) any inconsistency violating the restriction is found.
ValueError: if the restriction defined in the string is not supported.
"""
def
_get_kv
(
dotted_string
,
params_dict
):
"""Get keys and values indicated by dotted_string."""
if
_CONST_VALUE_RE
.
match
(
dotted_string
)
is
not
None
:
...
...
@@ -270,56 +287,64 @@ class ParamsDict(object):
tokens
=
restriction
.
split
(
'=='
)
_
,
left_v
,
_
,
right_v
=
_get_kvs
(
tokens
,
params_dict
)
if
left_v
!=
right_v
:
raise
KeyError
(
'Found inconsistncy between key `{}` and key `{}`.'
.
format
(
tokens
[
0
],
tokens
[
1
]))
raise
KeyError
(
'Found inconsistncy between key `{}` and key `{}`.'
.
format
(
tokens
[
0
],
tokens
[
1
]))
elif
'!='
in
restriction
:
tokens
=
restriction
.
split
(
'!='
)
_
,
left_v
,
_
,
right_v
=
_get_kvs
(
tokens
,
params_dict
)
if
left_v
==
right_v
:
raise
KeyError
(
'Found inconsistncy between key `{}` and key `{}`.'
.
format
(
tokens
[
0
],
tokens
[
1
]))
raise
KeyError
(
'Found inconsistncy between key `{}` and key `{}`.'
.
format
(
tokens
[
0
],
tokens
[
1
]))
elif
'<'
in
restriction
:
tokens
=
restriction
.
split
(
'<'
)
_
,
left_v
,
_
,
right_v
=
_get_kvs
(
tokens
,
params_dict
)
if
left_v
>=
right_v
:
raise
KeyError
(
'Found inconsistncy between key `{}` and key `{}`.'
.
format
(
tokens
[
0
],
tokens
[
1
]))
raise
KeyError
(
'Found inconsistncy between key `{}` and key `{}`.'
.
format
(
tokens
[
0
],
tokens
[
1
]))
elif
'<='
in
restriction
:
tokens
=
restriction
.
split
(
'<='
)
_
,
left_v
,
_
,
right_v
=
_get_kvs
(
tokens
,
params_dict
)
if
left_v
>
right_v
:
raise
KeyError
(
'Found inconsistncy between key `{}` and key `{}`.'
.
format
(
tokens
[
0
],
tokens
[
1
]))
raise
KeyError
(
'Found inconsistncy between key `{}` and key `{}`.'
.
format
(
tokens
[
0
],
tokens
[
1
]))
elif
'>'
in
restriction
:
tokens
=
restriction
.
split
(
'>'
)
_
,
left_v
,
_
,
right_v
=
_get_kvs
(
tokens
,
params_dict
)
if
left_v
<=
right_v
:
raise
KeyError
(
'Found inconsistncy between key `{}` and key `{}`.'
.
format
(
tokens
[
0
],
tokens
[
1
]))
raise
KeyError
(
'Found inconsistncy between key `{}` and key `{}`.'
.
format
(
tokens
[
0
],
tokens
[
1
]))
elif
'>='
in
restriction
:
tokens
=
restriction
.
split
(
'>='
)
_
,
left_v
,
_
,
right_v
=
_get_kvs
(
tokens
,
params_dict
)
if
left_v
<
right_v
:
raise
KeyError
(
'Found inconsistncy between key `{}` and key `{}`.'
.
format
(
tokens
[
0
],
tokens
[
1
]))
raise
KeyError
(
'Found inconsistncy between key `{}` and key `{}`.'
.
format
(
tokens
[
0
],
tokens
[
1
]))
else
:
raise
ValueError
(
'Unsupported relation in restriction.'
)
def
read_yaml_to_params_dict
(
file_path
):
def
read_yaml_to_params_dict
(
file_path
:
str
):
"""Reads a YAML file to a ParamsDict."""
with
tf
.
io
.
gfile
.
GFile
(
file_path
,
'r'
)
as
f
:
params_dict
=
yaml
.
load
(
f
)
params_dict
=
yaml
.
load
(
f
,
Loader
=
LOADER
)
return
ParamsDict
(
params_dict
)
def
save_params_dict_to_yaml
(
params
,
file_path
):
"""Saves the input ParamsDict to a YAML file."""
with
tf
.
io
.
gfile
.
GFile
(
file_path
,
'w'
)
as
f
:
def
_my_list_rep
(
dumper
,
data
):
# u'tag:yaml.org,2002:seq' is the YAML internal tag for sequence.
return
dumper
.
represent_sequence
(
u
'tag:yaml.org,2002:seq'
,
data
,
flow_style
=
True
)
yaml
.
add_representer
(
list
,
_my_list_rep
)
yaml
.
dump
(
params
.
as_dict
(),
f
,
default_flow_style
=
False
)
...
...
@@ -408,8 +433,8 @@ def override_params_dict(params, dict_or_string_or_yaml_file, is_strict):
Args:
params: a ParamsDict object to be overridden.
dict_or_string_or_yaml_file: a Python dict, JSON/YAML/CSV string or
path to
a YAML file specifying the parameters to be overridden.
dict_or_string_or_yaml_file: a Python dict, JSON/YAML/CSV string or
path to
a YAML file specifying the parameters to be overridden.
is_strict: a boolean specifying whether override is strict or not.
Returns:
...
...
@@ -428,12 +453,12 @@ def override_params_dict(params, dict_or_string_or_yaml_file, is_strict):
nested_csv_str_to_json_str
(
dict_or_string_or_yaml_file
))
except
ValueError
:
pass
params_dict
=
yaml
.
load
(
dict_or_string_or_yaml_file
)
params_dict
=
yaml
.
load
(
dict_or_string_or_yaml_file
,
Loader
=
LOADER
)
if
isinstance
(
params_dict
,
dict
):
params
.
override
(
params_dict
,
is_strict
)
else
:
with
tf
.
io
.
gfile
.
GFile
(
dict_or_string_or_yaml_file
)
as
f
:
params
.
override
(
yaml
.
load
(
f
),
is_strict
)
params
.
override
(
yaml
.
load
(
f
,
Loader
=
yaml
.
FullLoader
),
is_strict
)
else
:
raise
ValueError
(
'Unknown input type to parse.'
)
return
params
official/modeling/hyperparams/params_dict_test.py
View file @
f16a7b5b
# Copyright 201
9
The TensorFlow Authors. All Rights Reserved.
# Copyright 20
2
1 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for params_dict.py."""
...
...
@@ -56,8 +55,7 @@ class ParamsDictTest(tf.test.TestCase):
def
test_setattr
(
self
):
params
=
params_dict
.
ParamsDict
()
params
.
override
(
{
'a'
:
'aa'
,
'b'
:
2
,
'c'
:
None
},
is_strict
=
False
)
params
.
override
({
'a'
:
'aa'
,
'b'
:
2
,
'c'
:
None
},
is_strict
=
False
)
params
.
c
=
'ccc'
self
.
assertEqual
(
params
.
a
,
'aa'
)
self
.
assertEqual
(
params
.
b
,
2
)
...
...
@@ -65,17 +63,23 @@ class ParamsDictTest(tf.test.TestCase):
def
test_getattr
(
self
):
params
=
params_dict
.
ParamsDict
()
params
.
override
(
{
'a'
:
'aa'
,
'b'
:
2
,
'c'
:
None
},
is_strict
=
False
)
params
.
override
({
'a'
:
'aa'
,
'b'
:
2
,
'c'
:
None
},
is_strict
=
False
)
self
.
assertEqual
(
params
.
a
,
'aa'
)
self
.
assertEqual
(
params
.
b
,
2
)
self
.
assertEqual
(
params
.
c
,
None
)
def
test_delattr
(
self
):
params
=
params_dict
.
ParamsDict
()
params
.
override
(
{
'a'
:
'aa'
,
'b'
:
2
,
'c'
:
None
,
'd'
:
{
'd1'
:
1
,
'd2'
:
10
}},
is_strict
=
False
)
params
.
override
({
'a'
:
'aa'
,
'b'
:
2
,
'c'
:
None
,
'd'
:
{
'd1'
:
1
,
'd2'
:
10
}
},
is_strict
=
False
)
del
params
.
c
self
.
assertEqual
(
params
.
a
,
'aa'
)
self
.
assertEqual
(
params
.
b
,
2
)
...
...
@@ -87,22 +91,26 @@ class ParamsDictTest(tf.test.TestCase):
def
test_contains
(
self
):
params
=
params_dict
.
ParamsDict
()
params
.
override
(
{
'a'
:
'aa'
},
is_strict
=
False
)
params
.
override
({
'a'
:
'aa'
},
is_strict
=
False
)
self
.
assertIn
(
'a'
,
params
)
self
.
assertNotIn
(
'b'
,
params
)
def
test_get
(
self
):
params
=
params_dict
.
ParamsDict
()
params
.
override
(
{
'a'
:
'aa'
},
is_strict
=
False
)
params
.
override
({
'a'
:
'aa'
},
is_strict
=
False
)
self
.
assertEqual
(
params
.
get
(
'a'
),
'aa'
)
self
.
assertEqual
(
params
.
get
(
'b'
,
2
),
2
)
self
.
assertEqual
(
params
.
get
(
'b'
),
None
)
def
test_override_is_strict_true
(
self
):
params
=
params_dict
.
ParamsDict
(
{
'a'
:
'aa'
,
'b'
:
2
,
'c'
:
{
'c1'
:
'cc'
,
'c2'
:
20
}})
params
=
params_dict
.
ParamsDict
({
'a'
:
'aa'
,
'b'
:
2
,
'c'
:
{
'c1'
:
'cc'
,
'c2'
:
20
}
})
params
.
override
({
'a'
:
2
,
'c'
:
{
'c1'
:
'ccc'
}},
is_strict
=
True
)
self
.
assertEqual
(
params
.
a
,
2
)
self
.
assertEqual
(
params
.
c
.
c1
,
'ccc'
)
...
...
@@ -112,8 +120,14 @@ class ParamsDictTest(tf.test.TestCase):
params
.
override
({
'c'
:
{
'c3'
:
30
}},
is_strict
=
True
)
def
test_override_is_strict_false
(
self
):
params
=
params_dict
.
ParamsDict
(
{
'a'
:
'aa'
,
'b'
:
2
,
'c'
:
{
'c1'
:
10
,
'c2'
:
20
}})
params
=
params_dict
.
ParamsDict
({
'a'
:
'aa'
,
'b'
:
2
,
'c'
:
{
'c1'
:
10
,
'c2'
:
20
}
})
params
.
override
({
'a'
:
2
,
'c'
:
{
'c3'
:
3000
}},
is_strict
=
False
)
self
.
assertEqual
(
params
.
a
,
2
)
self
.
assertEqual
(
params
.
c
.
c3
,
3000
)
...
...
@@ -123,8 +137,14 @@ class ParamsDictTest(tf.test.TestCase):
self
.
assertEqual
(
params
.
c
.
c4
,
4444
)
def
test_as_dict
(
self
):
params
=
params_dict
.
ParamsDict
(
{
'a'
:
'aa'
,
'b'
:
2
,
'c'
:
{
'c1'
:
10
,
'c2'
:
20
}})
params
=
params_dict
.
ParamsDict
({
'a'
:
'aa'
,
'b'
:
2
,
'c'
:
{
'c1'
:
10
,
'c2'
:
20
}
})
params_d
=
params
.
as_dict
()
self
.
assertEqual
(
params_d
[
'a'
],
'aa'
)
self
.
assertEqual
(
params_d
[
'b'
],
2
)
...
...
@@ -134,21 +154,27 @@ class ParamsDictTest(tf.test.TestCase):
def
test_validate
(
self
):
# Raise error due to the unknown parameter.
with
self
.
assertRaises
(
KeyError
):
params
=
params_dict
.
ParamsDict
(
{
'a'
:
1
,
'b'
:
{
'a'
:
11
}},
[
'a == c'
]
)
params
=
params_dict
.
ParamsDict
(
{
'a'
:
1
,
'b'
:
{
'a'
:
11
}},
[
'a == c'
])
params
.
validate
(
)
# OK to check equality of two nested dicts.
params
=
params_dict
.
ParamsDict
(
{
'a'
:
1
,
'b'
:
{
'a'
:
10
},
'c'
:
{
'a'
:
10
}},
[
'b == c'
])
params
=
params_dict
.
ParamsDict
({
'a'
:
1
,
'b'
:
{
'a'
:
10
},
'c'
:
{
'a'
:
10
}
},
[
'b == c'
])
# Raise error due to inconsistency
with
self
.
assertRaises
(
KeyError
):
params
=
params_dict
.
ParamsDict
(
{
'a'
:
1
,
'c'
:
{
'a'
:
10
}},
[
'a == c.a'
]
)
params
=
params_dict
.
ParamsDict
(
{
'a'
:
1
,
'c'
:
{
'a'
:
10
}},
[
'a == c.a'
])
params
.
validate
(
)
# Valid rule.
params
=
params_dict
.
ParamsDict
(
{
'a'
:
1
,
'c'
:
{
'a'
:
1
}},
[
'a == c.a'
])
params
=
params_dict
.
ParamsDict
({
'a'
:
1
,
'c'
:
{
'a'
:
1
}},
[
'a == c.a'
])
# Overridding violates the existing rule, raise error upon validate.
params
.
override
({
'a'
:
11
})
...
...
@@ -156,12 +182,21 @@ class ParamsDictTest(tf.test.TestCase):
params
.
validate
()
# Valid restrictions with constant.
params
=
params_dict
.
ParamsDict
(
{
'a'
:
None
,
'c'
:
{
'a'
:
1
}},
[
'a == None'
,
'c.a == 1'
])
params
=
params_dict
.
ParamsDict
({
'a'
:
None
,
'c'
:
{
'a'
:
1
}
},
[
'a == None'
,
'c.a == 1'
])
params
.
validate
()
with
self
.
assertRaises
(
KeyError
):
params
=
params_dict
.
ParamsDict
(
{
'a'
:
4
,
'c'
:
{
'a'
:
1
}},
[
'a == None'
,
'c.a == 1'
])
params
=
params_dict
.
ParamsDict
({
'a'
:
4
,
'c'
:
{
'a'
:
1
}
},
[
'a == None'
,
'c.a == 1'
])
params
.
validate
()
class
ParamsDictIOTest
(
tf
.
test
.
TestCase
):
...
...
@@ -173,8 +208,14 @@ class ParamsDictIOTest(tf.test.TestCase):
return
temp_file
def
test_save_params_dict_to_yaml
(
self
):
params
=
params_dict
.
ParamsDict
(
{
'a'
:
'aa'
,
'b'
:
2
,
'c'
:
{
'c1'
:
10
,
'c2'
:
20
}})
params
=
params_dict
.
ParamsDict
({
'a'
:
'aa'
,
'b'
:
2
,
'c'
:
{
'c1'
:
10
,
'c2'
:
20
}
})
output_yaml_file
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
'params.yaml'
)
params_dict
.
save_params_dict_to_yaml
(
params
,
output_yaml_file
)
...
...
@@ -203,7 +244,12 @@ class ParamsDictIOTest(tf.test.TestCase):
def
test_override_params_dict_using_dict
(
self
):
params
=
params_dict
.
ParamsDict
({
'a'
:
1
,
'b'
:
2.5
,
'c'
:
[
3
,
4
],
'd'
:
'hello'
,
'e'
:
False
})
'a'
:
1
,
'b'
:
2.5
,
'c'
:
[
3
,
4
],
'd'
:
'hello'
,
'e'
:
False
})
override_dict
=
{
'b'
:
5.2
,
'c'
:
[
30
,
40
]}
params
=
params_dict
.
override_params_dict
(
params
,
override_dict
,
is_strict
=
True
)
...
...
@@ -215,7 +261,12 @@ class ParamsDictIOTest(tf.test.TestCase):
def
test_override_params_dict_using_yaml_string
(
self
):
params
=
params_dict
.
ParamsDict
({
'a'
:
1
,
'b'
:
2.5
,
'c'
:
[
3
,
4
],
'd'
:
'hello'
,
'e'
:
False
})
'a'
:
1
,
'b'
:
2.5
,
'c'
:
[
3
,
4
],
'd'
:
'hello'
,
'e'
:
False
})
override_yaml_string
=
"'b': 5.2
\n
'c': [30, 40]"
params
=
params_dict
.
override_params_dict
(
params
,
override_yaml_string
,
is_strict
=
True
)
...
...
@@ -227,8 +278,18 @@ class ParamsDictIOTest(tf.test.TestCase):
def
test_override_params_dict_using_json_string
(
self
):
params
=
params_dict
.
ParamsDict
({
'a'
:
1
,
'b'
:
{
'b1'
:
2
,
'b2'
:
[
2
,
3
],},
'd'
:
{
'd1'
:
{
'd2'
:
'hello'
}},
'e'
:
False
})
'a'
:
1
,
'b'
:
{
'b1'
:
2
,
'b2'
:
[
2
,
3
],
},
'd'
:
{
'd1'
:
{
'd2'
:
'hello'
}
},
'e'
:
False
})
override_json_string
=
"{ b: { b2: [3, 4] }, d: { d1: { d2: 'hi' } } }"
params
=
params_dict
.
override_params_dict
(
params
,
override_json_string
,
is_strict
=
True
)
...
...
@@ -240,8 +301,18 @@ class ParamsDictIOTest(tf.test.TestCase):
def
test_override_params_dict_using_csv_string
(
self
):
params
=
params_dict
.
ParamsDict
({
'a'
:
1
,
'b'
:
{
'b1'
:
2
,
'b2'
:
[
2
,
3
],},
'd'
:
{
'd1'
:
{
'd2'
:
'hello'
}},
'e'
:
False
})
'a'
:
1
,
'b'
:
{
'b1'
:
2
,
'b2'
:
[
2
,
3
],
},
'd'
:
{
'd1'
:
{
'd2'
:
'hello'
}
},
'e'
:
False
})
override_csv_string
=
"b.b2=[3,4], d.d1.d2='hi, world', e=gs://test"
params
=
params_dict
.
override_params_dict
(
params
,
override_csv_string
,
is_strict
=
True
)
...
...
@@ -250,10 +321,23 @@ class ParamsDictIOTest(tf.test.TestCase):
self
.
assertEqual
([
3
,
4
],
params
.
b
.
b2
)
self
.
assertEqual
(
'hi, world'
,
params
.
d
.
d1
.
d2
)
self
.
assertEqual
(
'gs://test'
,
params
.
e
)
# Test different float formats
override_csv_string
=
'b.b2=-1.e-3, d.d1.d2=+0.001, e=1e+3, a=-1.5E-3'
params
=
params_dict
.
override_params_dict
(
params
,
override_csv_string
,
is_strict
=
True
)
self
.
assertEqual
(
-
1e-3
,
params
.
b
.
b2
)
self
.
assertEqual
(
0.001
,
params
.
d
.
d1
.
d2
)
self
.
assertEqual
(
1e3
,
params
.
e
)
self
.
assertEqual
(
-
1.5e-3
,
params
.
a
)
def
test_override_params_dict_using_yaml_file
(
self
):
params
=
params_dict
.
ParamsDict
({
'a'
:
1
,
'b'
:
2.5
,
'c'
:
[
3
,
4
],
'd'
:
'hello'
,
'e'
:
False
})
'a'
:
1
,
'b'
:
2.5
,
'c'
:
[
3
,
4
],
'd'
:
'hello'
,
'e'
:
False
})
override_yaml_file
=
self
.
write_temp_file
(
'params.yaml'
,
r
"""
b: 5.2
...
...
@@ -321,8 +405,7 @@ class IOTest(tf.test.TestCase):
def
test_csv_str_load_unsupported_datatypes
(
self
):
csv_str
=
'a=[[1,2,3],[4,5,6]]'
self
.
assertRaises
(
ValueError
,
params_dict
.
nested_csv_str_to_json_str
,
self
.
assertRaises
(
ValueError
,
params_dict
.
nested_csv_str_to_json_str
,
csv_str
)
def
test_csv_str_to_json_str_spacing
(
self
):
...
...
official/modeling/multitask/__init__.py
0 → 100644
View file @
f16a7b5b
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
official/modeling/multitask/base_model.py
0 → 100644
View file @
f16a7b5b
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Abstraction of multi-task model."""
from
typing
import
Text
,
Dict
import
tensorflow
as
tf
class MultiTaskBaseModel(tf.Module):
  """Base class that holds multi-task model computation.

  Subclasses override `_instantiate_sub_tasks` to build one `tf.keras.Model`
  per task name. The mapping is built once at construction time and exposed
  read-only via the `sub_tasks` property.
  """

  def __init__(self, **kwargs):
    super().__init__(**kwargs)
    # Built eagerly so every sub-task model (and its variables) is tracked by
    # this tf.Module as soon as the object exists.
    self._sub_tasks = self._instantiate_sub_tasks()

  def _instantiate_sub_tasks(self) -> Dict[Text, tf.keras.Model]:
    """Abstract function that sets up the computation for each sub-task.

    Returns:
      A map from task name (as string) to a tf.keras.Model object that
      represents the sub-task in the multi-task pool.
    """
    # Fix: the original message referenced a nonexistent method name
    # ("_instantiate_sub_task_models"); report the actual method instead.
    raise NotImplementedError(
        "_instantiate_sub_tasks() is not implemented.")

  @property
  def sub_tasks(self):
    """Fetch a map of task name (string) to task model (tf.keras.Model)."""
    return self._sub_tasks

  def initialize(self):
    """Optional function that loads a pre-train checkpoint."""
    return
official/modeling/multitask/base_trainer.py
0 → 100644
View file @
f16a7b5b
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Multitask base trainer implementation.
The trainer derives from the Orbit `StandardTrainer` class.
"""
from
typing
import
Union
import
gin
import
orbit
import
tensorflow
as
tf
from
official.modeling.multitask
import
base_model
from
official.modeling.multitask
import
multitask
@gin.configurable
class MultiTaskBaseTrainer(orbit.StandardTrainer):
  """Multitask base trainer.

  Jointly trains all tasks of a `multitask.MultiTask` pool: each train step
  draws one batch per task and delegates to `MultiTask.joint_train_step`
  under the distribution strategy active at construction time.
  """

  def __init__(self,
               multi_task: multitask.MultiTask,
               multi_task_model: Union[tf.keras.Model,
                                       base_model.MultiTaskBaseModel],
               optimizer: tf.optimizers.Optimizer,
               trainer_options=None):
    # Strategy in scope at construction time; a no-op single-replica
    # strategy when none is active.
    self._strategy = tf.distribute.get_strategy()
    self._multi_task = multi_task
    self._multi_task_model = multi_task_model
    self._optimizer = optimizer

    # Loss/metric containers are built lazily through their properties.
    self._training_losses = None
    self._training_metrics = None
    self._global_step = orbit.utils.create_global_step()

    # Models may expose extra objects to include in the checkpoint.
    checkpoint_items = getattr(self.multi_task_model, "checkpoint_items", {})
    self._checkpoint = tf.train.Checkpoint(
        model=self.multi_task_model,
        optimizer=self.optimizer,
        global_step=self.global_step,
        **checkpoint_items)

    # One distributed training dataset per task, keyed by task name.
    train_datasets = {
        name: orbit.utils.make_distributed_dataset(
            self.strategy, task.build_inputs, task.task_config.train_data)
        for name, task in self.multi_task.tasks.items()
    }

    super().__init__(
        train_dataset=train_datasets,
        options=trainer_options or orbit.StandardTrainerOptions())

  def train_loop_begin(self):
    """Clean up states that hold losses and metrics."""
    for loss_metric in self.training_losses.values():
      loss_metric.reset_states()
    for task_metrics in self.training_metrics.values():
      for metric in task_metrics:
        metric.reset_states()

  def train_loop_end(self):
    """Record loss and metric values per task."""
    result = {
        task_name: {loss.name: loss.result()}
        for task_name, loss in self.training_losses.items()
    }
    for task_name, task_metrics in self.training_metrics.items():
      result[task_name].update(
          {metric.name: metric.result() for metric in task_metrics})
    # Note that, the learning rate schedule is managed by the keras optimizer
    # internally, which respects the number of backward pass as `iterations`.
    # The learning rate schedule does not follow the trainer logical global
    # step of multiple tasks.
    if callable(self.optimizer.learning_rate):
      result["learning_rate"] = self.optimizer.learning_rate(
          self.optimizer.iterations)
    else:
      result["learning_rate"] = self.optimizer.learning_rate
    return result

  @property
  def checkpoint(self):
    """Accesses the training checkpoint."""
    return self._checkpoint

  @property
  def training_losses(self):
    """Access training loss metric objects for all tasks."""
    if self._training_losses is None:
      # Lazily build one Mean tracker per task, plus the summed joint loss
      # of all tasks under the "total_loss" key.
      losses = {
          "total_loss":
              tf.keras.metrics.Mean("training_loss", dtype=tf.float32)
      }
      for name in self.multi_task.tasks:
        losses[name] = tf.keras.metrics.Mean("training_loss", dtype=tf.float32)
      self._training_losses = losses
    return self._training_losses

  @property
  def training_metrics(self):
    """Access training metric objects for all tasks."""
    if self._training_metrics is None:
      # Lazily build the per-task metric lists.
      self._training_metrics = {
          name: task.build_metrics(training=True)
          for name, task in self.multi_task.tasks.items()
      }
    return self._training_metrics

  @property
  def strategy(self):
    return self._strategy

  @property
  def multi_task(self):
    return self._multi_task

  @property
  def multi_task_model(self):
    return self._multi_task_model

  @property
  def optimizer(self):
    return self._optimizer

  @property
  def global_step(self):
    return self._global_step

  def train_step(self, iterator_map):
    """The default train step calling the multi-task train step.

    Args:
      iterator_map: a dictionary of task names and per-task dataset iterators.
    """

    def step_fn(inputs):
      losses = self.multi_task.joint_train_step(
          inputs,
          multi_task_model=self.multi_task_model,
          optimizer=self.optimizer,
          task_metrics=self.training_metrics)
      for key, loss in losses.items():
        self.training_losses[key].update_state(loss)

    # Pull one batch from every per-task iterator and run them jointly.
    self.strategy.run(
        step_fn, args=(tf.nest.map_structure(next, iterator_map),))
    self.global_step.assign_add(1)
official/modeling/multitask/base_trainer_test.py
0 → 100644
View file @
f16a7b5b
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask.base_trainer."""
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
tensorflow.python.distribute
import
combinations
from
tensorflow.python.distribute
import
strategy_combinations
from
official.modeling.multitask
import
base_trainer
from
official.modeling.multitask
import
configs
from
official.modeling.multitask
import
multitask
from
official.modeling.multitask
import
test_utils
def all_strategy_combinations():
  """Distribution strategies exercised by every test case (eager mode)."""
  distributions = [
      strategy_combinations.default_strategy,
      strategy_combinations.cloud_tpu_strategy,
      strategy_combinations.one_device_strategy_gpu,
  ]
  return combinations.combine(distribution=distributions, mode="eager")
class BaseTrainerTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for MultiTaskBaseTrainer built directly and from configs."""

  def _assert_joint_results(self, results):
    # Every task entry must report its loss plus its own accuracy metric.
    self.assertContainsSubset(["training_loss", "bar_acc"],
                              results["bar"].keys())
    self.assertContainsSubset(["training_loss", "foo_acc"],
                              results["foo"].keys())

  @combinations.generate(all_strategy_combinations())
  def test_multitask_joint_trainer(self, distribution):
    with distribution.scope():
      tasks = [
          test_utils.MockFooTask(params=test_utils.FooConfig(), name="foo"),
          test_utils.MockBarTask(params=test_utils.BarConfig(), name="bar"),
      ]
      test_multitask = multitask.MultiTask(
          tasks=tasks, task_weights={"foo": 1.0, "bar": 1.0})
      test_trainer = base_trainer.MultiTaskBaseTrainer(
          multi_task=test_multitask,
          multi_task_model=test_utils.MockMultiTaskModel(),
          optimizer=tf.keras.optimizers.SGD(0.1))
      results = test_trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
      self._assert_joint_results(results)

  def test_trainer_with_configs(self):
    # Build the same two-task pool from declarative configs.
    config = configs.MultiTaskConfig(
        task_routines=(
            configs.TaskRoutine(
                task_name="foo",
                task_config=test_utils.FooConfig(),
                task_weight=0.5),
            configs.TaskRoutine(
                task_name="bar",
                task_config=test_utils.BarConfig(),
                task_weight=0.5),
        ))
    test_multitask = multitask.MultiTask.from_config(config)
    test_trainer = base_trainer.MultiTaskBaseTrainer(
        multi_task=test_multitask,
        multi_task_model=test_utils.MockMultiTaskModel(),
        optimizer=tf.keras.optimizers.SGD(0.1))
    results = test_trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
    self._assert_joint_results(results)
    self.assertEqual(test_multitask.task_weight("foo"), 0.5)
    self.assertEqual(test_trainer.global_step.numpy(), 5)
    self.assertIn("learning_rate", results)
# Run the test suite when invoked as a script.
if __name__ == "__main__":
  tf.test.main()
official/modeling/multitask/configs.py
0 → 100644
View file @
f16a7b5b
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration definitions for multi-task training."""
from
typing
import
Optional
,
Tuple
import
dataclasses
from
official.core
import
config_definitions
as
cfg
from
official.modeling
import
hyperparams
@dataclasses.dataclass
class TaskRoutine(hyperparams.Config):
  """Routine describing one task inside a multi-task experiment."""
  # Name keying this task within the task pool.
  task_name: str = ""
  # Per-task configuration object.
  task_config: cfg.TaskConfig = None
  # Evaluation step count for this task; None defers to the caller's default.
  eval_steps: Optional[int] = None
  # Relative weight of this task's loss in joint training.
  task_weight: Optional[float] = 1.0
@dataclasses.dataclass
class MultiTaskConfig(hyperparams.Config):
  """Configuration for a pool of tasks trained or evaluated together."""
  # Path to a checkpoint used for initialization; empty means none.
  init_checkpoint: str = ""
  # Optional shared model configuration.
  model: hyperparams.Config = None
  # One TaskRoutine per task in the pool.
  task_routines: Tuple[TaskRoutine, ...] = ()
@dataclasses.dataclass
class ProportionalSampleConfig(hyperparams.Config):
  """Parameters for the proportional task sampler."""
  # NOTE(review): presumably the exponent applied when weighting tasks —
  # confirm against the sampler implementation.
  alpha: float = 1.0
@dataclasses.dataclass
class AnnealingSampleConfig(hyperparams.Config):
  """Parameters for the annealing task-sampling schedule."""
  steps_per_epoch: int = 5
  total_steps: int = 20
@dataclasses.dataclass
class TaskSamplingConfig(hyperparams.OneOfConfig):
  """One-of config selecting the task sampling scheme."""
  # Selects which sibling field is active (e.g. "proportional").
  type: str = ""
  uniform: hyperparams.Config = hyperparams.Config()
  proportional: ProportionalSampleConfig = ProportionalSampleConfig()
  annealing: AnnealingSampleConfig = AnnealingSampleConfig()
@dataclasses.dataclass
class MultiTaskTrainerConfig(cfg.TrainerConfig):
  """Trainer config extended with multi-task specific options."""
  trainer_type: str = "interleaving"
  task_sampler: TaskSamplingConfig = TaskSamplingConfig(type="proportional")
@dataclasses.dataclass
class MultiTaskExperimentConfig(hyperparams.Config):
  """An experiment config for multi-task training and multi-task evaluation."""
  task: MultiTaskConfig = MultiTaskConfig()
  trainer: MultiTaskTrainerConfig = MultiTaskTrainerConfig()
  runtime: cfg.RuntimeConfig = cfg.RuntimeConfig()
@dataclasses.dataclass
class MultiEvalExperimentConfig(cfg.ExperimentConfig):
  """An experiment config for single-task training and multi-task evaluation.

  Attributes:
    eval_tasks: individual evaluation tasks.
  """
  eval_tasks: MultiTaskConfig = MultiTaskConfig()
official/modeling/multitask/evaluator.py
0 → 100644
View file @
f16a7b5b
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multitask Evaluator implementation.
The evaluator implements the Orbit `AbstractEvaluator` interface.
"""
from
typing
import
Optional
,
Union
import
gin
import
orbit
import
tensorflow
as
tf
from
official.core
import
train_utils
from
official.modeling.multitask
import
base_model
from
official.modeling.multitask
import
multitask
@gin.configurable
class MultiTaskEvaluator(orbit.AbstractEvaluator):
  """Implements the common trainer shared for TensorFlow models."""

  def __init__(self,
               task: multitask.MultiTask,
               model: Union[tf.keras.Model, base_model.MultiTaskBaseModel],
               global_step: Optional[tf.Variable] = None,
               checkpoint_exporter: Optional[
                   train_utils.BestCheckpointExporter] = None):
    """Initialize common trainer for TensorFlow models.

    Args:
      task: A multitask.MultiTask instance.
      model: tf.keras.Model instance.
      global_step: the global step variable.
      checkpoint_exporter: an object that has the `maybe_export_checkpoint`
        interface.
    """
    # Gets the current distribution strategy. If not inside any strategy scope,
    # it gets a single-replica no-op strategy.
    self._strategy = tf.distribute.get_strategy()
    self._task = task
    self._model = model
    self._global_step = global_step or orbit.utils.create_global_step()
    self._checkpoint_exporter = checkpoint_exporter
    # Checkpoint tracks the model and step so evaluation can restore them.
    self._checkpoint = tf.train.Checkpoint(
        global_step=self.global_step,
        model=self.model)

    # Loss/metric containers are built lazily by the properties below; they
    # are materialized for the first time inside `get_function`.
    self._validation_losses = None
    self._validation_metrics = None

    # Builds per-task datasets.
    # NOTE(review): the loop variable `task` shadows the constructor argument
    # of the same name; the argument is already saved in `self._task`.
    self.eval_datasets = {}
    for name, task in self.task.tasks.items():
      self.eval_datasets[name] = orbit.utils.make_distributed_dataset(
          self.strategy, task.build_inputs, task.task_config.validation_data)

    # Builds per-task validation loops.
    def get_function(task_name, task):

      task_metrics = self.validation_metrics[task_name]
      task_loss = self.validation_losses[task_name]
      # A MultiTaskBaseModel carries one sub-model per task; a plain
      # tf.keras.Model is shared by all tasks.
      if isinstance(self.model, base_model.MultiTaskBaseModel):
        model = self.model.sub_tasks[task_name]
      else:
        model = self.model

      def step_fn(inputs):
        # Runs one validation step and folds the task loss into its Mean.
        logs = task.validation_step(inputs, model=model, metrics=task_metrics)
        task_loss.update_state(logs[task.loss])
        return logs

      @tf.function
      def eval_step_fn(iterator):
        distributed_outputs = self.strategy.run(step_fn, args=(next(iterator),))
        # Unwrap per-replica values so the reduce_fn sees local results.
        return tf.nest.map_structure(self.strategy.experimental_local_results,
                                     distributed_outputs)

      return orbit.utils.create_loop_fn(eval_step_fn)

    self.task_fns = {
        name: get_function(name, task)
        for name, task in self.task.tasks.items()
    }

  @property
  def strategy(self):
    return self._strategy

  @property
  def task(self):
    return self._task

  @property
  def model(self):
    return self._model

  @property
  def global_step(self):
    return self._global_step

  @property
  def validation_losses(self):
    """Accesses the validation loss metric object."""
    if self._validation_losses is None:
      # Builds the per-task metrics and losses.
      self._validation_losses = {}
      for name in self.task.tasks:
        self._validation_losses[name] = tf.keras.metrics.Mean(
            "validation_loss", dtype=tf.float32)
    return self._validation_losses

  @property
  def validation_metrics(self):
    """Accesses all validation metric metric objects."""
    if self._validation_metrics is None:
      # Builds the per-task metrics and losses.
      self._validation_metrics = {}
      for name, task in self.task.tasks.items():
        self._validation_metrics[name] = task.build_metrics(training=False)
    return self._validation_metrics

  @property
  def checkpoint(self):
    """Accesses the training checkpoint."""
    return self._checkpoint

  def evaluate(self, num_steps: tf.Tensor):
    """Performs evaluation for each `EvalTask`."""
    # Reset accumulators so each evaluate() call reports fresh numbers.
    for metric in self.validation_losses.values():
      metric.reset_states()
    for metrics in self.validation_metrics.values():
      for metric in metrics:
        metric.reset_states()
    results = {}
    eval_iters = tf.nest.map_structure(iter, self.eval_datasets)

    for name, task_eval_loop in self.task_fns.items():
      outputs = None
      eval_iter = eval_iters[name]
      task = self.task.tasks[name]
      # A task-specific eval step count takes precedence over num_steps.
      task_eval_steps = self.task.task_eval_steps(name) or num_steps
      outputs = task_eval_loop(
          eval_iter,
          task_eval_steps,
          state=outputs,
          reduce_fn=task.aggregate_logs)
      task_metrics = self.validation_metrics[name]
      task_loss = self.validation_losses[name]
      logs = {}
      for metric in task_metrics + [task_loss]:
        logs[metric.name] = metric.result()
      if outputs:
        metrics = task.reduce_aggregated_logs(
            outputs, global_step=self.global_step)
        logs.update(metrics)
      results[name] = logs

    if self._checkpoint_exporter:
      self._checkpoint_exporter.maybe_export_checkpoint(
          self.checkpoint, results, self.global_step.numpy())
    return results
official/modeling/multitask/evaluator_test.py
0 → 100644
View file @
f16a7b5b
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask.evaluator."""
from
absl.testing
import
parameterized
import
numpy
as
np
import
tensorflow
as
tf
from
tensorflow.python.distribute
import
combinations
from
tensorflow.python.distribute
import
strategy_combinations
from
official.core
import
base_task
from
official.core
import
config_definitions
as
cfg
from
official.modeling.multitask
import
evaluator
from
official.modeling.multitask
import
multitask
def all_strategy_combinations():
  """Distribution strategies exercised by every test case (eager mode)."""
  distributions = [
      strategy_combinations.default_strategy,
      strategy_combinations.cloud_tpu_strategy,
      strategy_combinations.one_device_strategy_gpu,
  ]
  return combinations.combine(distribution=distributions, mode="eager")
class MockModel(tf.keras.Model):
  """Test model whose added loss encodes which input keys it received.

  `call` adds a 0.0 loss when the batch contains a "y" key and a 1.0 loss
  otherwise, letting tests distinguish the per-task inputs via the reported
  validation loss.
  """

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.dense = tf.keras.layers.Dense(1)

  def call(self, inputs):
    # Fix: removed leftover debug `print(inputs, type(inputs))` — it spammed
    # stdout on every traced call and had no test value.
    if "y" in inputs:
      self.add_loss(tf.zeros((1,), dtype=tf.float32))
    else:
      self.add_loss(tf.ones((1,), dtype=tf.float32))
    return self.dense(inputs["x"])
class MockTask(base_task.Task):
  """Mock task object for testing."""

  def build_metrics(self, training: bool = True):
    """Returns a single accuracy metric regardless of the training flag."""
    del training
    return [tf.keras.metrics.Accuracy(name="acc")]

  def build_inputs(self, params):
    """Builds an infinite dataset of constant features and zero labels."""

    def _synthesize(_):
      features = tf.zeros(shape=(2,), dtype=tf.float32)
      label = tf.zeros([1], dtype=tf.int32)
      if self.name == "bar":
        # The "bar" task additionally feeds a "y" key (see MockModel).
        return dict(x=features, y=features), label
      return dict(x=features), label

    dataset = tf.data.Dataset.range(1).repeat().map(
        _synthesize, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return dataset.prefetch(buffer_size=1).batch(2, drop_remainder=True)

  def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
    """Adds a constant "counter" entry on top of the base validation logs."""
    logs = super().validation_step(inputs, model, metrics)
    logs["counter"] = tf.ones((1,), dtype=tf.float32)
    return logs

  def aggregate_logs(self, state, step_outputs):
    """Accumulates per-step outputs as stacked numpy arrays, keyed by name."""
    if state is None:
      state = {}
    for key, per_replica_values in step_outputs.items():
      stacked = np.concatenate(
          [np.expand_dims(v.numpy(), axis=0) for v in per_replica_values])
      state.setdefault(key, []).append(stacked)
    return state

  def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
    """Collapses each accumulated list into the scalar sum of its entries."""
    for key, chunks in aggregated_logs.items():
      aggregated_logs[key] = np.sum(np.stack(chunks, axis=0))
    return aggregated_logs
class EvaluatorTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for MultiTaskEvaluator over a two-task MockTask pool."""

  def _make_evaluator(self):
    # A "bar" task (feeds x and y) and a "foo" task (feeds x only).
    tasks = [
        MockTask(params=cfg.TaskConfig(), name="bar"),
        MockTask(params=cfg.TaskConfig(), name="foo"),
    ]
    pool = multitask.MultiTask(tasks=tasks)
    return evaluator.MultiTaskEvaluator(task=pool, model=MockModel())

  @combinations.generate(all_strategy_combinations())
  def test_multitask_evaluator(self, distribution):
    with distribution.scope():
      test_evaluator = self._make_evaluator()
      results = test_evaluator.evaluate(
          tf.convert_to_tensor(1, dtype=tf.int32))
      self.assertContainsSubset(["validation_loss", "acc"],
                                results["bar"].keys())
      self.assertContainsSubset(["validation_loss", "acc"],
                                results["foo"].keys())
      # MockModel adds a zero loss when "y" is present ("bar"), one otherwise.
      self.assertEqual(results["bar"]["validation_loss"], 0.0)
      self.assertEqual(results["foo"]["validation_loss"], 1.0)

  @combinations.generate(all_strategy_combinations())
  def test_multitask_evaluator_numpy_metrics(self, distribution):
    with distribution.scope():
      test_evaluator = self._make_evaluator()
      results = test_evaluator.evaluate(
          tf.convert_to_tensor(5, dtype=tf.int32))
      # "counter" is 1.0 per step per replica, summed by MockTask's reducer.
      replicas = distribution.num_replicas_in_sync
      self.assertEqual(results["bar"]["counter"], 5. * replicas)
      self.assertEqual(results["foo"]["counter"], 5. * replicas)
# Standard TensorFlow test entry point.
if __name__ == "__main__":
  tf.test.main()
official/modeling/multitask/interleaving_trainer.py
0 → 100644
View file @
f16a7b5b
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multitask trainer that interleaves each task's train step."""
from
typing
import
Union
import
gin
import
orbit
import
tensorflow
as
tf
from
official.modeling.multitask
import
base_model
from
official.modeling.multitask
import
base_trainer
from
official.modeling.multitask
import
multitask
from
official.modeling.multitask
import
task_sampler
as
sampler
@gin.configurable
class MultiTaskInterleavingTrainer(base_trainer.MultiTaskBaseTrainer):
  """MultiTask trainer that interleaves task update.

  At each call of `train_step`, exactly one task is sampled from the
  `task_sampler`'s (possibly step-dependent) multinomial distribution and a
  single train step is run for that task only.
  """

  def __init__(self,
               multi_task: multitask.MultiTask,
               multi_task_model: Union[tf.keras.Model,
                                       base_model.MultiTaskBaseModel],
               optimizer: tf.optimizers.Optimizer,
               task_sampler: sampler.TaskSampler,
               trainer_options=None):
    super(MultiTaskInterleavingTrainer, self).__init__(
        multi_task=multi_task,
        multi_task_model=multi_task_model,
        optimizer=optimizer,
        trainer_options=trainer_options)
    self._task_sampler = task_sampler

    # Build per task train step.
    def _get_task_step(task_name, task):

      def step_fn(inputs):
        # A MultiTaskBaseModel exposes one sub-model per task; a plain
        # Keras model is shared across all tasks.
        if isinstance(self.multi_task_model, base_model.MultiTaskBaseModel):
          task_model = self.multi_task_model.sub_tasks[task_name]
        else:
          task_model = self.multi_task_model
        task_logs = task.train_step(
            inputs,
            model=task_model,
            optimizer=self.optimizer,
            metrics=self.training_metrics[task_name])
        self.training_losses[task_name].update_state(task_logs[task.loss])

      return step_fn

    self._task_train_step_map = {
        name: _get_task_step(name, task)
        for name, task in self.multi_task.tasks.items()
    }

    # TODO(haozhangthu): Add taskwise step counter to train_loop_end for logging
    # on TensorBoard.
    # One independent step counter per task, in addition to the trainer's
    # shared global step.
    self._task_step_counters = {
        name: orbit.utils.create_global_step() for name in self.multi_task.tasks
    }

  def task_step_counter(self, name):
    """Returns the per-task step counter variable for task `name`."""
    return self._task_step_counters[name]

  def train_step(self, iterator_map):
    """Samples one task and runs a single train step for it.

    Args:
      iterator_map: dict of task name -> dataset iterator.
    """
    # Sample one task to train according to a multinomial distribution.
    # Seeding with the global step makes the sampling sequence deterministic.
    rn = tf.random.stateless_uniform(shape=[], seed=(0, self.global_step))
    cumulative_sample_distribution = self._task_sampler.task_cumulative_distribution(
        self.global_step)
    # Prepend a [0.0] for indexing convenience.
    cumulative_sample_distribution = tf.concat(
        [tf.constant([0.0], dtype=tf.float32), cumulative_sample_distribution],
        axis=0)

    for idx, (name, _) in enumerate(self.multi_task.tasks.items()):
      begin = cumulative_sample_distribution[idx]
      end = cumulative_sample_distribution[idx + 1]
      # NOTE(review): `and` on two scalar tensors relies on eager tensor
      # truthiness; inside a tf.function this would raise. Presumably this
      # loop only ever runs eagerly — confirm before tracing it.
      if rn >= begin and rn < end:
        self._strategy.run(
            self._task_train_step_map[name], args=(next(iterator_map[name]),))
        # Only the sampled task's counter (and the global step) advance.
        self.global_step.assign_add(1)
        self.task_step_counter(name).assign_add(1)
official/modeling/multitask/interleaving_trainer_test.py
0 → 100644
View file @
f16a7b5b
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask.interleaving_trainer."""
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
tensorflow.python.distribute
import
combinations
from
tensorflow.python.distribute
import
strategy_combinations
from
official.modeling.multitask
import
configs
from
official.modeling.multitask
import
interleaving_trainer
from
official.modeling.multitask
import
multitask
from
official.modeling.multitask
import
task_sampler
from
official.modeling.multitask
import
test_utils
def all_strategy_combinations():
  """Returns eager-mode test combinations over the supported strategies."""
  strategies = [
      strategy_combinations.default_strategy,
      strategy_combinations.cloud_tpu_strategy,
      strategy_combinations.one_device_strategy_gpu,
  ]
  return combinations.combine(distribution=strategies, mode="eager")
class InterleavingTrainerTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for MultiTaskInterleavingTrainer across strategies."""

  @combinations.generate(all_strategy_combinations())
  def test_multitask_interleaving_trainer(self, distribution):
    with distribution.scope():
      tasks = [
          test_utils.MockFooTask(params=test_utils.FooConfig(), name="foo"),
          test_utils.MockBarTask(params=test_utils.BarConfig(), name="bar")
      ]
      test_multitask = multitask.MultiTask(tasks=tasks)
      test_optimizer = tf.keras.optimizers.SGD(0.1)
      model = test_utils.MockMultiTaskModel()
      # Uniform sampling: every task is equally likely per step.
      sampler = task_sampler.UniformTaskSampler(
          task_weights=test_multitask.task_weights)
      test_trainer = interleaving_trainer.MultiTaskInterleavingTrainer(
          multi_task=test_multitask,
          multi_task_model=model,
          optimizer=test_optimizer,
          task_sampler=sampler)
      results = test_trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
    self.assertContainsSubset(["training_loss", "bar_acc"],
                              results["bar"].keys())
    self.assertContainsSubset(["training_loss", "foo_acc"],
                              results["foo"].keys())

  @combinations.generate(all_strategy_combinations())
  def test_trainer_with_configs(self, distribution):
    # Build the MultiTask from config routines instead of task instances.
    config = configs.MultiTaskConfig(
        task_routines=(configs.TaskRoutine(
            task_name="foo",
            task_config=test_utils.FooConfig(),
            task_weight=3.0),
                       configs.TaskRoutine(
                           task_name="bar",
                           task_config=test_utils.BarConfig(),
                           task_weight=1.0)))
    with distribution.scope():
      test_multitask = multitask.MultiTask.from_config(config)
      test_optimizer = tf.keras.optimizers.SGD(0.1)
      model = test_utils.MockMultiTaskModel()
      num_step = 1000
      # Annealing sampler: distribution shifts over 5 "epochs" of training.
      sampler = task_sampler.AnnealingTaskSampler(
          task_weights=test_multitask.task_weights,
          steps_per_epoch=num_step / 5,
          total_steps=num_step)
      test_trainer = interleaving_trainer.MultiTaskInterleavingTrainer(
          multi_task=test_multitask,
          multi_task_model=model,
          optimizer=test_optimizer,
          task_sampler=sampler)
      results = test_trainer.train(tf.convert_to_tensor(num_step, dtype=tf.int32))
    self.assertContainsSubset(["training_loss", "bar_acc"],
                              results["bar"].keys())
    self.assertContainsSubset(["training_loss", "foo_acc"],
                              results["foo"].keys())
    self.assertEqual(test_trainer.global_step.numpy(), num_step)
    # Exactly one task is stepped per global step, so per-task counters
    # must partition the total step count.
    bar_sampled_step = test_trainer.task_step_counter("bar").numpy()
    foo_sampled_step = test_trainer.task_step_counter("foo").numpy()
    self.assertEqual(bar_sampled_step + foo_sampled_step, num_step)
# Standard TensorFlow test entry point.
if __name__ == "__main__":
  tf.test.main()
official/modeling/multitask/multitask.py
0 → 100644
View file @
f16a7b5b
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Experimental MultiTask base class for multi-task training/evaluation."""
import
abc
from
typing
import
Dict
,
List
,
Optional
,
Text
,
Union
import
tensorflow
as
tf
from
official.core
import
base_task
from
official.core
import
config_definitions
from
official.core
import
task_factory
from
official.modeling
import
optimization
from
official.modeling.multitask
import
base_model
from
official.modeling.multitask
import
configs
OptimizationConfig
=
optimization
.
OptimizationConfig
RuntimeConfig
=
config_definitions
.
RuntimeConfig
class MultiTask(tf.Module, metaclass=abc.ABCMeta):
  """A multi-task class to manage multiple tasks.

  Holds a name-keyed collection of `base_task.Task` objects together with
  optional per-task weights and per-task eval step counts, and provides a
  joint (summed-loss) train step over a `MultiTaskBaseModel`.
  """

  def __init__(self,
               tasks: Union[Dict[Text, base_task.Task], List[base_task.Task]],
               task_weights: Optional[Dict[str, Union[float, int]]] = None,
               task_eval_steps: Optional[Dict[str, int]] = None,
               name: Optional[str] = None):
    """MultiTask initialization.

    Args:
      tasks: a list or a flat dict of Task.
      task_weights: a dict of (task, task weight), task weight can be applied
        directly during loss summation in a joint backward step, or it can be
        used to sample task among interleaved backward step.
      task_eval_steps: a dict of (task, eval steps).
      name: the instance name of a MultiTask object.

    Raises:
      ValueError: if `tasks` is a list containing duplicate task names, or
        is neither a list nor a dict.
    """
    super().__init__(name=name)
    if isinstance(tasks, list):
      self._tasks = {}
      for task in tasks:
        if task.name in self._tasks:
          raise ValueError("Duplicated tasks found, task.name is %s" %
                           task.name)
        self._tasks[task.name] = task
    elif isinstance(tasks, dict):
      self._tasks = tasks
    else:
      raise ValueError("The tasks argument has an invalid type: %s" %
                       type(tasks))
    # Normalize optional dicts so every task has an entry: missing eval
    # steps default to None and missing weights default to 1.0.
    task_eval_steps = task_eval_steps or {}
    self._task_eval_steps = {
        task_name: task_eval_steps.get(task_name, None)
        for task_name in self.tasks
    }
    task_weights = task_weights or {}
    self._task_weights = {
        task_name: task_weights.get(task_name, 1.0)
        for task_name in self.tasks
    }

  @classmethod
  def from_config(cls, config: configs.MultiTaskConfig, logging_dir=None):
    """Builds a MultiTask from the task routines in `config`."""
    tasks = {}
    task_eval_steps = {}
    task_weights = {}
    for task_routine in config.task_routines:
      task_name = task_routine.task_name
      tasks[task_name] = task_factory.get_task(
          task_routine.task_config, logging_dir=logging_dir)
      task_eval_steps[task_name] = task_routine.eval_steps
      task_weights[task_name] = task_routine.task_weight
    return cls(
        tasks, task_eval_steps=task_eval_steps, task_weights=task_weights)

  @property
  def tasks(self):
    """Dict of task name -> Task."""
    return self._tasks

  def task_eval_steps(self, task_name):
    """Returns the configured eval steps for `task_name` (None = default)."""
    return self._task_eval_steps[task_name]

  def task_weight(self, task_name):
    """Returns the loss/sampling weight for `task_name`."""
    return self._task_weights[task_name]

  @property
  def task_weights(self):
    """Dict of task name -> weight."""
    return self._task_weights

  @classmethod
  def create_optimizer(cls,
                       optimizer_config: OptimizationConfig,
                       runtime_config: Optional[RuntimeConfig] = None):
    # Delegates to the single-task factory; the same optimizer is shared
    # across all tasks.
    return base_task.Task.create_optimizer(
        optimizer_config=optimizer_config, runtime_config=runtime_config)

  def joint_train_step(self, task_inputs,
                       multi_task_model: base_model.MultiTaskBaseModel,
                       optimizer: tf.keras.optimizers.Optimizer, task_metrics):
    """The joint train step.

    Args:
      task_inputs: a dictionary of task names and per-task features.
      multi_task_model: a MultiTaskBaseModel instance.
      optimizer: a tf.optimizers.Optimizer.
      task_metrics: a dictionary of task names and per-task metrics.

    Returns:
      A dictionary of losses, including per-task losses and their weighted
      sum under the "total_loss" key.

    Raises:
      ValueError: if a task's iterator output is neither a (features,
        labels) tuple nor a dict.
    """
    losses = {}
    with tf.GradientTape() as tape:
      total_loss = 0.0
      for name, model in multi_task_model.sub_tasks.items():
        inputs = task_inputs[name]
        if isinstance(inputs, tuple) and len(inputs) == 2:
          features, labels = inputs
        elif isinstance(inputs, dict):
          # Dict inputs double as labels (self-supervised-style tasks).
          features, labels = inputs, inputs
        else:
          raise ValueError("The iterator output is neither a tuple nor a "
                           "dictionary. It is not implemented to support "
                           "such outputs.")
        outputs = model(features, training=True)
        task_loss = self.tasks[name].build_losses(labels, outputs)
        task_weight = self.task_weight(name)
        total_loss += task_weight * task_loss
        losses[name] = task_loss
        self.tasks[name].process_metrics(task_metrics[name], labels, outputs)

      # Scales loss as the default gradients allreduce performs sum inside
      # the optimizer.
      scaled_loss = total_loss / tf.distribute.get_strategy(
      ).num_replicas_in_sync
    tvars = multi_task_model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    optimizer.apply_gradients(list(zip(grads, tvars)))
    losses["total_loss"] = total_loss
    return losses
official/modeling/multitask/task_sampler.py
0 → 100644
View file @
f16a7b5b
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils to sample tasks for interleaved optimization."""
import
abc
from
typing
import
Union
,
Dict
,
Text
import
tensorflow
as
tf
from
official.modeling.multitask
import
configs
class TaskSampler(tf.Module, metaclass=abc.ABCMeta):
  """An abstract class defining task sampling API for interleaving trainer."""

  def __init__(self, task_weights: Dict[Text, Union[float, int]]):
    # task_weights: dict of task name -> relative weight; concrete samplers
    # decide how (or whether) weights influence the sampling distribution.
    self._task_weights = task_weights

  @abc.abstractmethod
  def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
    """Compute cumulative distribution to sample tasks.

    It calculates the cumulative distribution of the multinomial task
    distribution with respect to which to be sampled against.

    Args:
      global_step: A tensor indicating current progress of training.

    Returns:
      A float tensor with shape (#(task), 1) that represents the cumulative
      sampling distribution.
    """
    pass
class UniformTaskSampler(TaskSampler):
  """Sample all tasks uniformly."""

  def __init__(self, task_weights: Dict[Text, Union[float, int]]):
    super().__init__(task_weights=task_weights)
    num_tasks = len(self._task_weights)
    # Equal mass per task; the cumulative distribution is therefore the
    # static vector [1/n, 2/n, ..., 1.0] and never depends on the step.
    uniform_probs = tf.constant(
        [1.0 / num_tasks] * num_tasks, dtype=tf.float32)
    self._uniform_cumulative = tf.math.cumsum(uniform_probs)

  def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
    """Returns the precomputed, step-independent cumulative distribution."""
    del global_step  # Uniform sampling ignores training progress.
    return self._uniform_cumulative
class ProportionalTaskSampler(TaskSampler):
  """Sample tasks proportional to task weights."""

  def __init__(self,
               task_weights: Dict[Text, Union[float, int]],
               alpha: float = 1.0):
    super().__init__(task_weights=task_weights)
    self._alpha = tf.cast(alpha, dtype=tf.float32)
    # Raise each weight to the power alpha, then normalize into a
    # probability distribution; the cumulative form is static.
    weights = tf.constant(
        list(self._task_weights.values()), dtype=tf.float32)
    powered = tf.math.pow(weights, self._alpha)
    distribution = powered / tf.reduce_sum(powered)
    # Note: renamed from the misspelled `_porportional_cumulative`.
    self._proportional_cumulative = tf.math.cumsum(distribution)

  def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
    """Returns the precomputed, step-independent cumulative distribution."""
    del global_step  # Proportional sampling ignores training progress.
    return self._proportional_cumulative
class AnnealingTaskSampler(TaskSampler):
  """Sample tasks according to task weights as well as training progress.

  The exponent applied to the task weights anneals from 1.0 toward 0.2 as
  training progresses (epoch by epoch), gradually flattening the sampling
  distribution toward uniform.
  """

  def __init__(self,
               task_weights: Dict[Text, Union[float, int]],
               steps_per_epoch: int,
               total_steps: int):
    super(AnnealingTaskSampler, self).__init__(task_weights=task_weights)
    self._steps_per_epoch = tf.cast(steps_per_epoch, dtype=tf.float32)
    self._total_epochs = tf.cast(
        total_steps / self._steps_per_epoch, dtype=tf.float32)

  def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
    """Returns the step-dependent cumulative sampling distribution."""
    cur_epoch = tf.math.floor(
        tf.cast(global_step, dtype=tf.float32) / self._steps_per_epoch)
    # Annealed exponent: 1.0 at epoch 1, decreasing linearly by 0.8 over
    # the remaining epochs; 1e-10 guards division by zero when there is a
    # single epoch. The exact values are pinned by task_sampler_test.
    alpha = 1.0 - 0.8 * (cur_epoch - 1) / (self._total_epochs - 1 + 1e-10)
    task_weight_dict_ordered_list = [
        weight for _, weight in self._task_weights.items()
    ]
    task_sizes = tf.math.pow(
        tf.constant(task_weight_dict_ordered_list, dtype=tf.float32),
        tf.cast(alpha, dtype=tf.float32))
    dynamic_task_distribution = task_sizes / tf.reduce_sum(task_sizes)
    return tf.math.cumsum(dynamic_task_distribution)
def get_task_sampler(config: configs.TaskSamplingConfig,
                     task_weights: Dict[Text, float]) -> TaskSampler:
  """Utils to create task sampler with configuration and task weights."""
  oneof_config = config.get()
  sampler_type = config.type
  if sampler_type == 'uniform':
    return UniformTaskSampler(task_weights=task_weights)
  if sampler_type == 'proportional':
    return ProportionalTaskSampler(
        task_weights=task_weights, alpha=oneof_config.alpha)
  if sampler_type == 'annealing':
    return AnnealingTaskSampler(
        task_weights=task_weights,
        steps_per_epoch=oneof_config.steps_per_epoch,
        total_steps=oneof_config.total_steps)
  raise RuntimeError('Task sampler type not supported')
official/modeling/multitask/task_sampler_test.py
0 → 100644
View file @
f16a7b5b
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask.task_sampler."""
import
tensorflow
as
tf
from
official.modeling.multitask
import
configs
from
official.modeling.multitask
import
task_sampler
as
sampler
class TaskSamplerTest(tf.test.TestCase):
  """Tests the three task samplers against hand-computed distributions."""

  def setUp(self):
    super(TaskSamplerTest, self).setUp()
    self._task_weights = {'A': 1.0, 'B': 2.0, 'C': 3.0}

  def test_uniform_sample_distribution(self):
    uniform_sampler = sampler.get_task_sampler(
        configs.TaskSamplingConfig(type='uniform'), self._task_weights)
    # Uniform sampling ignores the step: always [1/3, 2/3, 1].
    for step in range(5):
      cumulative_distribution = uniform_sampler.task_cumulative_distribution(
          tf.constant(step, dtype=tf.int64))
      self.assertAllClose([0.333333, 0.666666, 1.0],
                          cumulative_distribution.numpy())

  def test_proportional_sample_distribution(self):
    prop_sampler = sampler.get_task_sampler(
        configs.TaskSamplingConfig(
            type='proportional',
            proportional=configs.ProportionalSampleConfig(alpha=2.0)),
        self._task_weights)
    # CumulativeOf(Normalize([1.0^2, 2.0^2, 3.0^2]))
    for step in range(5):
      cumulative_distribution = prop_sampler.task_cumulative_distribution(
          tf.constant(step, dtype=tf.int64))
      self.assertAllClose([0.07142857, 0.35714286, 1.0],
                          cumulative_distribution.numpy())

  def test_annealing_sample_distribution(self):
    num_epoch = 3
    step_per_epoch = 6
    annel_sampler = sampler.get_task_sampler(
        configs.TaskSamplingConfig(
            type='annealing',
            annealing=configs.AnnealingSampleConfig(
                steps_per_epoch=step_per_epoch,
                total_steps=step_per_epoch * num_epoch)), self._task_weights)
    global_step = tf.Variable(
        0, dtype=tf.int64, name='global_step', trainable=False)
    # The distribution flattens toward uniform as epochs advance; values
    # match the annealed-alpha formula in AnnealingTaskSampler.
    expected_cumulative_epochs = [[0.12056106, 0.4387236, 1.0],
                                  [0.16666667, 0.5, 1.0],
                                  [0.22477472, 0.5654695, 1.0]]
    for epoch in range(num_epoch):
      for _ in range(step_per_epoch):
        cumulative_distribution = annel_sampler.task_cumulative_distribution(
            tf.constant(global_step, dtype=tf.int64))
        global_step.assign_add(1)
        self.assertAllClose(expected_cumulative_epochs[epoch],
                            cumulative_distribution.numpy())
# Standard TensorFlow test entry point.
if __name__ == '__main__':
  tf.test.main()
Prev
1
2
3
4
5
6
7
8
9
…
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment