Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
f16a7b5b
Unverified
Commit
f16a7b5b
authored
May 04, 2021
by
vedanshu
Committed by
GitHub
May 04, 2021
Browse files
Merge pull request
#1
from tensorflow/master
new pull
parents
8e9296ff
8f58f396
Changes
298
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1591 additions
and
304 deletions
+1591
-304
official/modeling/hyperparams/__init__.py
official/modeling/hyperparams/__init__.py
+2
-3
official/modeling/hyperparams/base_config.py
official/modeling/hyperparams/base_config.py
+41
-19
official/modeling/hyperparams/base_config_test.py
official/modeling/hyperparams/base_config_test.py
+68
-7
official/modeling/hyperparams/config_definitions.py
official/modeling/hyperparams/config_definitions.py
+10
-178
official/modeling/hyperparams/oneof.py
official/modeling/hyperparams/oneof.py
+6
-11
official/modeling/hyperparams/oneof_test.py
official/modeling/hyperparams/oneof_test.py
+13
-9
official/modeling/hyperparams/params_dict.py
official/modeling/hyperparams/params_dict.py
+60
-35
official/modeling/hyperparams/params_dict_test.py
official/modeling/hyperparams/params_dict_test.py
+125
-42
official/modeling/multitask/__init__.py
official/modeling/multitask/__init__.py
+14
-0
official/modeling/multitask/base_model.py
official/modeling/multitask/base_model.py
+60
-0
official/modeling/multitask/base_trainer.py
official/modeling/multitask/base_trainer.py
+176
-0
official/modeling/multitask/base_trainer_test.py
official/modeling/multitask/base_trainer_test.py
+90
-0
official/modeling/multitask/configs.py
official/modeling/multitask/configs.py
+79
-0
official/modeling/multitask/evaluator.py
official/modeling/multitask/evaluator.py
+172
-0
official/modeling/multitask/evaluator_test.py
official/modeling/multitask/evaluator_test.py
+138
-0
official/modeling/multitask/interleaving_trainer.py
official/modeling/multitask/interleaving_trainer.py
+92
-0
official/modeling/multitask/interleaving_trainer_test.py
official/modeling/multitask/interleaving_trainer_test.py
+101
-0
official/modeling/multitask/multitask.py
official/modeling/multitask/multitask.py
+148
-0
official/modeling/multitask/task_sampler.py
official/modeling/multitask/task_sampler.py
+121
-0
official/modeling/multitask/task_sampler_test.py
official/modeling/multitask/task_sampler_test.py
+75
-0
No files found.
Too many changes to show.
To preserve performance only
298 of 298+
files are displayed.
Plain diff
Email patch
official/modeling/hyperparams/__init__.py
View file @
f16a7b5b
# Lint as: python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -12,7 +11,7 @@
...
@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
"""Hyperparams package definition."""
"""Hyperparams package definition."""
# pylint: disable=g-multiple-import
# pylint: disable=g-multiple-import
from
official.modeling.hyperparams.base_config
import
*
from
official.modeling.hyperparams.base_config
import
*
...
...
official/modeling/hyperparams/base_config.py
View file @
f16a7b5b
# Lint as: python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -12,17 +11,13 @@
...
@@ -12,17 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
"""Base configurations to standardize experiments."""
from
__future__
import
absolute_import
"""Base configurations to standardize experiments."""
from
__future__
import
division
# from __future__ import google_type_annotations
from
__future__
import
print_function
import
copy
import
copy
import
functools
import
functools
from
typing
import
Any
,
List
,
Mapping
,
Optional
,
Type
from
typing
import
Any
,
List
,
Mapping
,
Optional
,
Type
from
absl
import
logging
import
dataclasses
import
dataclasses
import
tensorflow
as
tf
import
tensorflow
as
tf
...
@@ -35,11 +30,15 @@ from official.modeling.hyperparams import params_dict
...
@@ -35,11 +30,15 @@ from official.modeling.hyperparams import params_dict
class
Config
(
params_dict
.
ParamsDict
):
class
Config
(
params_dict
.
ParamsDict
):
"""The base configuration class that supports YAML/JSON based overrides.
"""The base configuration class that supports YAML/JSON based overrides.
* It recursively enforces a whitelist of basic types and container types, so
Because of YAML/JSON serialization limitations, some semantics of dataclass
are not supported:
* It recursively enforces a allowlist of basic types and container types, so
it avoids surprises with copy and reuse caused by unanticipated types.
it avoids surprises with copy and reuse caused by unanticipated types.
*
I
t converts
d
ict to Config even within sequences,
*
Warning: i
t converts
D
ict to
`
Config
`
even within sequences,
e.g. for config = Config({'key': [([{'a': 42}],)]),
e.g. for config = Config({'key': [([{'a': 42}],)]),
type(config.key[0][0][0]) is Config rather than dict.
type(config.key[0][0][0]) is Config rather than dict.
If you define/annotate some field as Dict, the field will convert to a
`Config` instance and lose the dictionary type.
"""
"""
# It's safe to add bytes and other immutable types here.
# It's safe to add bytes and other immutable types here.
...
@@ -142,10 +141,11 @@ class Config(params_dict.ParamsDict):
...
@@ -142,10 +141,11 @@ class Config(params_dict.ParamsDict):
return
subconfig_type
return
subconfig_type
def
__post_init__
(
self
,
default_params
,
restrictions
,
*
args
,
**
kwargs
):
def
__post_init__
(
self
,
default_params
,
restrictions
,
*
args
,
**
kwargs
):
super
().
__init__
(
default_params
=
default_params
,
super
().
__init__
(
restrictions
=
restrictions
,
default_params
=
default_params
,
*
args
,
restrictions
=
restrictions
,
**
kwargs
)
*
args
,
**
kwargs
)
def
_set
(
self
,
k
,
v
):
def
_set
(
self
,
k
,
v
):
"""Overrides same method in ParamsDict.
"""Overrides same method in ParamsDict.
...
@@ -160,13 +160,32 @@ class Config(params_dict.ParamsDict):
...
@@ -160,13 +160,32 @@ class Config(params_dict.ParamsDict):
RuntimeError
RuntimeError
"""
"""
subconfig_type
=
self
.
_get_subconfig_type
(
k
)
subconfig_type
=
self
.
_get_subconfig_type
(
k
)
if
isinstance
(
v
,
dict
):
def
is_null
(
k
):
if
k
not
in
self
.
__dict__
or
not
self
.
__dict__
[
k
]:
if
k
not
in
self
.
__dict__
or
not
self
.
__dict__
[
k
]:
return
True
return
False
if
isinstance
(
v
,
dict
):
if
is_null
(
k
):
# If the key not exist or the value is None, a new Config-family object
# If the key not exist or the value is None, a new Config-family object
# sould be created for the key.
# sould be created for the key.
self
.
__dict__
[
k
]
=
subconfig_type
(
v
)
self
.
__dict__
[
k
]
=
subconfig_type
(
v
)
else
:
else
:
self
.
__dict__
[
k
].
override
(
v
)
self
.
__dict__
[
k
].
override
(
v
)
elif
not
is_null
(
k
)
and
isinstance
(
v
,
self
.
SEQUENCE_TYPES
)
and
all
(
[
not
isinstance
(
e
,
self
.
IMMUTABLE_TYPES
)
for
e
in
v
]):
if
len
(
self
.
__dict__
[
k
])
==
len
(
v
):
for
i
in
range
(
len
(
v
)):
self
.
__dict__
[
k
][
i
].
override
(
v
[
i
])
elif
not
all
([
isinstance
(
e
,
self
.
IMMUTABLE_TYPES
)
for
e
in
v
]):
logging
.
warning
(
"The list/tuple don't match the value dictionaries provided. Thus, "
'the list/tuple is determined by the type annotation and '
'values provided. This is error-prone.'
)
self
.
__dict__
[
k
]
=
self
.
_import_config
(
v
,
subconfig_type
)
else
:
self
.
__dict__
[
k
]
=
self
.
_import_config
(
v
,
subconfig_type
)
else
:
else
:
self
.
__dict__
[
k
]
=
self
.
_import_config
(
v
,
subconfig_type
)
self
.
__dict__
[
k
]
=
self
.
_import_config
(
v
,
subconfig_type
)
...
@@ -220,16 +239,19 @@ class Config(params_dict.ParamsDict):
...
@@ -220,16 +239,19 @@ class Config(params_dict.ParamsDict):
}
}
def
replace
(
self
,
**
kwargs
):
def
replace
(
self
,
**
kwargs
):
"""Like `override`, but returns a copy with the current config unchanged."""
"""Overrides/returns a unlocked copy with the current config unchanged."""
params
=
self
.
__class__
(
self
)
# pylint: disable=protected-access
params
.
override
(
kwargs
,
is_strict
=
True
)
params
=
copy
.
deepcopy
(
self
)
params
.
_locked
=
False
params
.
_override
(
kwargs
,
is_strict
=
True
)
# pylint: enable=protected-access
return
params
return
params
@
classmethod
@
classmethod
def
from_yaml
(
cls
,
file_path
:
str
):
def
from_yaml
(
cls
,
file_path
:
str
):
# Note: This only works if the Config has all default values.
# Note: This only works if the Config has all default values.
with
tf
.
io
.
gfile
.
GFile
(
file_path
,
'r'
)
as
f
:
with
tf
.
io
.
gfile
.
GFile
(
file_path
,
'r'
)
as
f
:
loaded
=
yaml
.
load
(
f
)
loaded
=
yaml
.
load
(
f
,
Loader
=
yaml
.
FullLoader
)
config
=
cls
()
config
=
cls
()
config
.
override
(
loaded
)
config
.
override
(
loaded
)
return
config
return
config
...
...
official/modeling/hyperparams/base_config_test.py
View file @
f16a7b5b
# Lint as: python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -12,7 +11,6 @@
...
@@ -12,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
import
pprint
import
pprint
from
typing
import
List
,
Tuple
from
typing
import
List
,
Tuple
...
@@ -45,6 +43,17 @@ class DumpConfig3(DumpConfig2):
...
@@ -45,6 +43,17 @@ class DumpConfig3(DumpConfig2):
g
:
Tuple
[
DumpConfig1
,
...]
=
(
DumpConfig1
(),)
g
:
Tuple
[
DumpConfig1
,
...]
=
(
DumpConfig1
(),)
@
dataclasses
.
dataclass
class
DumpConfig4
(
DumpConfig2
):
x
:
int
=
3
@
dataclasses
.
dataclass
class
DummyConfig5
(
base_config
.
Config
):
y
:
Tuple
[
DumpConfig2
,
...]
=
(
DumpConfig2
(),
DumpConfig4
())
z
:
Tuple
[
str
]
=
(
'a'
,)
class
BaseConfigTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
class
BaseConfigTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
def
assertHasSameTypes
(
self
,
c
,
d
,
msg
=
''
):
def
assertHasSameTypes
(
self
,
c
,
d
,
msg
=
''
):
...
@@ -106,6 +115,22 @@ class BaseConfigTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -106,6 +115,22 @@ class BaseConfigTest(parameterized.TestCase, tf.test.TestCase):
self
.
assertEqual
(
config
.
g
[
0
].
a
,
4
)
self
.
assertEqual
(
config
.
g
[
0
].
a
,
4
)
self
.
assertEqual
(
config
.
g
[
0
].
b
,
'new text 3'
)
self
.
assertEqual
(
config
.
g
[
0
].
b
,
'new text 3'
)
def
test_replace
(
self
):
config
=
DumpConfig2
()
new_config
=
config
.
replace
(
e
=
{
'a'
:
2
})
self
.
assertEqual
(
new_config
.
e
.
a
,
2
)
self
.
assertIsInstance
(
new_config
.
e
,
DumpConfig1
)
config
=
DumpConfig2
(
e
=
DumpConfig2
())
new_config
=
config
.
replace
(
e
=
{
'c'
:
4
})
self
.
assertEqual
(
new_config
.
e
.
c
,
4
)
self
.
assertIsInstance
(
new_config
.
e
,
DumpConfig2
)
config
=
DumpConfig3
()
new_config
=
config
.
replace
(
g
=
[{
'a'
:
4
,
'b'
:
'new text 3'
}])
self
.
assertIsInstance
(
new_config
.
g
[
0
],
DumpConfig1
)
self
.
assertEqual
(
new_config
.
g
[
0
].
a
,
4
)
@
parameterized
.
parameters
(
@
parameterized
.
parameters
(
(
'_locked'
,
"The key '_locked' is internally reserved."
),
(
'_locked'
,
"The key '_locked' is internally reserved."
),
(
'_restrictions'
,
"The key '_restrictions' is internally reserved."
),
(
'_restrictions'
,
"The key '_restrictions' is internally reserved."
),
...
@@ -147,10 +172,8 @@ class BaseConfigTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -147,10 +172,8 @@ class BaseConfigTest(parameterized.TestCase, tf.test.TestCase):
params
.
override
({
'c'
:
{
'c3'
:
30
}},
is_strict
=
True
)
params
.
override
({
'c'
:
{
'c3'
:
30
}},
is_strict
=
True
)
config
=
base_config
.
Config
({
'key'
:
[{
'a'
:
42
}]})
config
=
base_config
.
Config
({
'key'
:
[{
'a'
:
42
}]})
config
.
override
({
'key'
:
[{
'b'
:
43
}]})
with
self
.
assertRaisesRegex
(
KeyError
,
"The key 'b' does not exist"
):
self
.
assertEqual
(
config
.
key
[
0
].
b
,
43
)
config
.
override
({
'key'
:
[{
'b'
:
43
}]})
with
self
.
assertRaisesRegex
(
AttributeError
,
'The key `a` does not exist'
):
_
=
config
.
key
[
0
].
a
@
parameterized
.
parameters
(
@
parameterized
.
parameters
(
(
lambda
x
:
x
,
'Unknown type'
),
(
lambda
x
:
x
,
'Unknown type'
),
...
@@ -294,6 +317,44 @@ class BaseConfigTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -294,6 +317,44 @@ class BaseConfigTest(parameterized.TestCase, tf.test.TestCase):
]),
]),
"['s', 1, 1.0, True, None, {}, [], (), {8: 9, (2,): (3, [4], {6: 7})}]"
)
"['s', 1, 1.0, True, None, {}, [], (), {8: 9, (2,): (3, [4], {6: 7})}]"
)
def
test_with_restrictions
(
self
):
restrictions
=
[
'e.a<c'
]
config
=
DumpConfig2
(
restrictions
=
restrictions
)
config
.
validate
()
def
test_nested_tuple
(
self
):
config
=
DummyConfig5
()
config
.
override
({
'y'
:
[{
'c'
:
4
,
'd'
:
'new text 3'
,
'e'
:
{
'a'
:
2
}
},
{
'c'
:
0
,
'd'
:
'new text 3'
,
'e'
:
{
'a'
:
2
}
}],
'z'
:
[
'a'
,
'b'
,
'c'
],
})
self
.
assertEqual
(
config
.
y
[
0
].
c
,
4
)
self
.
assertEqual
(
config
.
y
[
1
].
c
,
0
)
self
.
assertIsInstance
(
config
.
y
[
0
],
DumpConfig2
)
self
.
assertIsInstance
(
config
.
y
[
1
],
DumpConfig4
)
self
.
assertSameElements
(
config
.
z
,
[
'a'
,
'b'
,
'c'
])
def
test_override_by_empty_sequence
(
self
):
config
=
DummyConfig5
()
config
.
override
({
'y'
:
[],
'z'
:
(),
},
is_strict
=
True
)
self
.
assertEmpty
(
config
.
y
)
self
.
assertEmpty
(
config
.
z
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
tf
.
test
.
main
()
official/modeling/hyperparams/config_definitions.py
View file @
f16a7b5b
# Lint as: python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -12,124 +11,18 @@
...
@@ -12,124 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
"""Common configuration settings."""
from
typing
import
Optional
,
Union
"""Common configuration settings."""
# pylint:disable=wildcard-import
import
dataclasses
import
dataclasses
from
official.core.config_definitions
import
*
from
official.modeling.hyperparams
import
base_config
from
official.modeling.hyperparams
import
base_config
from
official.modeling.optimization.configs
import
optimization_config
from
official.utils
import
registry
OptimizationConfig
=
optimization_config
.
OptimizationConfig
@
dataclasses
.
dataclass
class
DataConfig
(
base_config
.
Config
):
"""The base configuration for building datasets.
Attributes:
input_path: The path to the input. It can be either (1) a file pattern, or
(2) multiple file patterns separated by comma. It should not be specified
when the following `tfds_name` is specified.
tfds_name: The name of the tensorflow dataset (TFDS). It should not be
specified when the above `input_path` is specified.
tfds_split: A str indicating which split of the data to load from TFDS. It
is required when above `tfds_name` is specified.
global_batch_size: The global batch size across all replicas.
is_training: Whether this data is used for training or not.
drop_remainder: Whether the last batch should be dropped in the case it has
fewer than `global_batch_size` elements.
shuffle_buffer_size: The buffer size used for shuffling training data.
cache: Whether to cache dataset examples. Can be used to avoid re-reading
from disk on the second epoch. Requires significant memory overhead.
cycle_length: The number of files that will be processed concurrently when
interleaving files.
block_length: The number of consecutive elements to produce from each input
element before cycling to another input element when interleaving files.
sharding: Whether sharding is used in the input pipeline.
examples_consume: An `integer` specifying the number of examples it will
produce. If positive, it only takes this number of examples and raises
tf.error.OutOfRangeError after that. Default is -1, meaning it will
exhaust all the examples in the dataset.
tfds_data_dir: A str specifying the directory to read/write TFDS data.
tfds_download: A bool to indicate whether to download data using TFDS.
tfds_as_supervised: A bool. When loading dataset from TFDS, if True,
the returned tf.data.Dataset will have a 2-tuple structure (input, label)
according to builder.info.supervised_keys; if False, the default,
the returned tf.data.Dataset will have a dictionary with all the features.
tfds_skip_decoding_feature: A str to indicate which features are skipped
for decoding when loading dataset from TFDS. Use comma to separate
multiple features. The main use case is to skip the image/video decoding
for better performance.
"""
input_path
:
str
=
""
tfds_name
:
str
=
""
tfds_split
:
str
=
""
global_batch_size
:
int
=
0
is_training
:
bool
=
None
drop_remainder
:
bool
=
True
shuffle_buffer_size
:
int
=
100
cache
:
bool
=
False
cycle_length
:
int
=
8
block_length
:
int
=
1
sharding
:
bool
=
True
examples_consume
:
int
=
-
1
tfds_data_dir
:
str
=
""
tfds_download
:
bool
=
False
tfds_as_supervised
:
bool
=
False
tfds_skip_decoding_feature
:
str
=
""
@
dataclasses
.
dataclass
class
RuntimeConfig
(
base_config
.
Config
):
"""High-level configurations for Runtime.
These include parameters that are not directly related to the experiment,
e.g. directories, accelerator type, etc.
Attributes:
distribution_strategy: e.g. 'mirrored', 'tpu', etc.
enable_xla: Whether or not to enable XLA.
per_gpu_thread_count: thread count per GPU.
gpu_thread_mode: Whether and how the GPU device uses its own threadpool.
dataset_num_private_threads: Number of threads for a private threadpool
created for all datasets computation.
tpu: The address of the TPU to use, if any.
num_gpus: The number of GPUs to use, if any.
worker_hosts: comma-separated list of worker ip:port pairs for running
multi-worker models with DistributionStrategy.
task_index: If multi-worker training, the task index of this worker.
all_reduce_alg: Defines the algorithm for performing all-reduce.
num_packs: Sets `num_packs` in the cross device ops used in
MirroredStrategy. For details, see tf.distribute.NcclAllReduce.
mixed_precision_dtype: dtype of mixed precision policy. It can be 'float32',
'float16', or 'bfloat16'.
loss_scale: The type of loss scale, or 'float' value. This is used when
setting the mixed precision policy.
run_eagerly: Whether or not to run the experiment eagerly.
batchnorm_spatial_persistent: Whether or not to enable the spatial
persistent mode for CuDNN batch norm kernel for improved GPU performance.
"""
distribution_strategy
:
str
=
"mirrored"
enable_xla
:
bool
=
False
gpu_thread_mode
:
Optional
[
str
]
=
None
dataset_num_private_threads
:
Optional
[
int
]
=
None
per_gpu_thread_count
:
int
=
0
tpu
:
Optional
[
str
]
=
None
num_gpus
:
int
=
0
worker_hosts
:
Optional
[
str
]
=
None
task_index
:
int
=
-
1
all_reduce_alg
:
Optional
[
str
]
=
None
num_packs
:
int
=
1
mixed_precision_dtype
:
Optional
[
str
]
=
None
loss_scale
:
Optional
[
Union
[
str
,
float
]]
=
None
run_eagerly
:
bool
=
False
batchnorm_spatial_persistent
:
bool
=
False
# TODO(hongkuny): These configs are used in models that are going to deprecate.
# Once those models are removed, we should delete this file to avoid confusion.
# Users should not use this file anymore.
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
class
TensorboardConfig
(
base_config
.
Config
):
class
TensorboardConfig
(
base_config
.
Config
):
"""Configuration for Tensorboard.
"""Configuration for Tensorboard.
...
@@ -151,75 +44,14 @@ class CallbacksConfig(base_config.Config):
...
@@ -151,75 +44,14 @@ class CallbacksConfig(base_config.Config):
Attributes:
Attributes:
enable_checkpoint_and_export: Whether or not to enable checkpoints as a
enable_checkpoint_and_export: Whether or not to enable checkpoints as a
Callback. Defaults to True.
Callback. Defaults to True.
enable_backup_and_restore: Whether or not to add BackupAndRestore
callback. Defaults to True.
enable_tensorboard: Whether or not to enable Tensorboard as a Callback.
enable_tensorboard: Whether or not to enable Tensorboard as a Callback.
Defaults to True.
Defaults to True.
enable_time_history: Whether or not to enable TimeHistory Callbacks.
enable_time_history: Whether or not to enable TimeHistory Callbacks.
Defaults to True.
Defaults to True.
"""
"""
enable_checkpoint_and_export
:
bool
=
True
enable_checkpoint_and_export
:
bool
=
True
enable_backup_and_restore
:
bool
=
False
enable_tensorboard
:
bool
=
True
enable_tensorboard
:
bool
=
True
enable_time_history
:
bool
=
True
enable_time_history
:
bool
=
True
@
dataclasses
.
dataclass
class
TrainerConfig
(
base_config
.
Config
):
"""Configuration for trainer.
Attributes:
optimizer_config: optimizer config, it includes optimizer, learning rate,
and warmup schedule configs.
train_tf_while_loop: whether or not to use tf while loop.
train_tf_function: whether or not to use tf_function for training loop.
eval_tf_function: whether or not to use tf_function for eval.
steps_per_loop: number of steps per loop.
summary_interval: number of steps between each summary.
checkpoint_interval: number of steps between checkpoints.
max_to_keep: max checkpoints to keep.
continuous_eval_timeout: maximum number of seconds to wait between
checkpoints, if set to None, continuous eval will wait indefinitely.
train_steps: number of train steps.
validation_steps: number of eval steps. If `None`, the entire eval dataset
is used.
validation_interval: number of training steps to run between evaluations.
"""
optimizer_config
:
OptimizationConfig
=
OptimizationConfig
()
train_tf_while_loop
:
bool
=
True
train_tf_function
:
bool
=
True
eval_tf_function
:
bool
=
True
steps_per_loop
:
int
=
1000
summary_interval
:
int
=
1000
checkpoint_interval
:
int
=
1000
max_to_keep
:
int
=
5
continuous_eval_timeout
:
Optional
[
int
]
=
None
train_steps
:
int
=
0
validation_steps
:
Optional
[
int
]
=
None
validation_interval
:
int
=
1000
@
dataclasses
.
dataclass
class
TaskConfig
(
base_config
.
Config
):
model
:
base_config
.
Config
=
None
train_data
:
DataConfig
=
DataConfig
()
validation_data
:
DataConfig
=
DataConfig
()
@
dataclasses
.
dataclass
class
ExperimentConfig
(
base_config
.
Config
):
"""Top-level configuration."""
task
:
TaskConfig
=
TaskConfig
()
trainer
:
TrainerConfig
=
TrainerConfig
()
runtime
:
RuntimeConfig
=
RuntimeConfig
()
_REGISTERED_CONFIGS
=
{}
def
register_config_factory
(
name
):
"""Register ExperimentConfig factory method."""
return
registry
.
register
(
_REGISTERED_CONFIGS
,
name
)
def
get_exp_config_creater
(
exp_name
:
str
):
"""Looks up ExperimentConfig factory methods."""
exp_creater
=
registry
.
lookup
(
_REGISTERED_CONFIGS
,
exp_name
)
return
exp_creater
official/modeling/hyperparams/oneof.py
View file @
f16a7b5b
# Lint as: python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -12,7 +11,7 @@
...
@@ -12,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
"""Config class that supports oneof functionality."""
"""Config class that supports oneof functionality."""
from
typing
import
Optional
from
typing
import
Optional
...
@@ -38,15 +37,12 @@ class OneOfConfig(base_config.Config):
...
@@ -38,15 +37,12 @@ class OneOfConfig(base_config.Config):
if
self
.
type
is
None
:
if
self
.
type
is
None
:
return
{
'type'
:
None
}
return
{
'type'
:
None
}
elif
self
.
__dict__
[
'type'
]
not
in
self
.
__dict__
:
elif
self
.
__dict__
[
'type'
]
not
in
self
.
__dict__
:
raise
ValueError
(
raise
ValueError
(
'type: {!r} is not a valid key!'
.
format
(
'type: {!r} is not a valid key!'
.
format
(
self
.
__dict__
[
'type'
]))
self
.
__dict__
[
'type'
]))
else
:
else
:
chosen_type
=
self
.
type
chosen_type
=
self
.
type
chosen_value
=
self
.
__dict__
[
chosen_type
]
chosen_value
=
self
.
__dict__
[
chosen_type
]
return
{
return
{
'type'
:
self
.
type
,
chosen_type
:
self
.
_export_config
(
chosen_value
)}
'type'
:
self
.
type
,
chosen_type
:
self
.
_export_config
(
chosen_value
)
}
def
get
(
self
):
def
get
(
self
):
"""Returns selected config based on the value of type.
"""Returns selected config based on the value of type.
...
@@ -57,6 +53,5 @@ class OneOfConfig(base_config.Config):
...
@@ -57,6 +53,5 @@ class OneOfConfig(base_config.Config):
if
chosen_type
is
None
:
if
chosen_type
is
None
:
return
None
return
None
if
chosen_type
not
in
self
.
__dict__
:
if
chosen_type
not
in
self
.
__dict__
:
raise
ValueError
(
raise
ValueError
(
'type: {!r} is not a valid key!'
.
format
(
self
.
type
))
'type: {!r} is not a valid key!'
.
format
(
self
.
type
))
return
self
.
__dict__
[
chosen_type
]
return
self
.
__dict__
[
chosen_type
]
official/modeling/hyperparams/oneof_test.py
View file @
f16a7b5b
# Lint as: python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -12,7 +11,6 @@
...
@@ -12,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
import
dataclasses
import
dataclasses
import
tensorflow
as
tf
import
tensorflow
as
tf
...
@@ -48,12 +46,18 @@ class Network(base_config.Config):
...
@@ -48,12 +46,18 @@ class Network(base_config.Config):
class
OneOfTest
(
tf
.
test
.
TestCase
):
class
OneOfTest
(
tf
.
test
.
TestCase
):
def
test_to_dict
(
self
):
def
test_to_dict
(
self
):
network_params
=
{
'backbone'
:
{
'type'
:
'resnet'
,
network_params
=
{
'resnet'
:
{
'model_depth'
:
50
}
'backbone'
:
{
},
'type'
:
'resnet'
,
'output_layer'
:
{
'type'
:
'single'
,
'resnet'
:
{
'single'
:
1000
}
'model_depth'
:
50
}
}
},
'output_layer'
:
{
'type'
:
'single'
,
'single'
:
1000
}
}
network_config
=
Network
(
network_params
)
network_config
=
Network
(
network_params
)
self
.
assertEqual
(
network_config
.
as_dict
(),
network_params
)
self
.
assertEqual
(
network_config
.
as_dict
(),
network_params
)
...
...
official/modeling/hyperparams/params_dict.py
View file @
f16a7b5b
# Copyright 201
9
The TensorFlow Authors. All Rights Reserved.
# Copyright 20
2
1 The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -11,12 +11,8 @@
...
@@ -11,12 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
"""A parameter dictionary class which supports the nest structure."""
from
__future__
import
absolute_import
"""A parameter dictionary class which supports the nest structure."""
from
__future__
import
division
from
__future__
import
print_function
import
collections
import
collections
import
copy
import
copy
...
@@ -30,7 +26,8 @@ import yaml
...
@@ -30,7 +26,8 @@ import yaml
# key-value pair string. It splits each k-v pair on the = sign, and
# key-value pair string. It splits each k-v pair on the = sign, and
# matches on values that are within single quotes, double quotes, single
# matches on values that are within single quotes, double quotes, single
# values (e.g. floats, ints, etc.), and a lists within brackets.
# values (e.g. floats, ints, etc.), and a lists within brackets.
_PARAM_RE
=
re
.
compile
(
r
"""
_PARAM_RE
=
re
.
compile
(
r
"""
(?P<name>[a-zA-Z][\w\.]*) # variable name: "var" or "x"
(?P<name>[a-zA-Z][\w\.]*) # variable name: "var" or "x"
\s*=\s*
\s*=\s*
((?P<val>\'(.*?)\' # single quote
((?P<val>\'(.*?)\' # single quote
...
@@ -44,6 +41,26 @@ _PARAM_RE = re.compile(r"""
...
@@ -44,6 +41,26 @@ _PARAM_RE = re.compile(r"""
_CONST_VALUE_RE
=
re
.
compile
(
r
'(\d.*|-\d.*|None)'
)
_CONST_VALUE_RE
=
re
.
compile
(
r
'(\d.*|-\d.*|None)'
)
# Yaml loader with an implicit resolver to parse float decimal and exponential
# format. The regular experission parse the following cases:
# 1- Decimal number with an optional exponential term.
# 2- Integer number with an exponential term.
# 3- Decimal number with an optional exponential term.
# 4- Decimal number.
LOADER
=
yaml
.
SafeLoader
LOADER
.
add_implicit_resolver
(
'tag:yaml.org,2002:float'
,
re
.
compile
(
r
'''
^(?:[-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
|
[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
|
\\.[0-9_]+(?:[eE][-+][0-9]+)?
|
[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*)$'''
,
re
.
X
),
list
(
'-+0123456789.'
))
class
ParamsDict
(
object
):
class
ParamsDict
(
object
):
"""A hyperparameter container class."""
"""A hyperparameter container class."""
...
@@ -72,7 +89,6 @@ class ParamsDict(object):
...
@@ -72,7 +89,6 @@ class ParamsDict(object):
if
default_params
is
None
:
if
default_params
is
None
:
default_params
=
{}
default_params
=
{}
self
.
override
(
default_params
,
is_strict
=
False
)
self
.
override
(
default_params
,
is_strict
=
False
)
self
.
validate
()
def
_set
(
self
,
k
,
v
):
def
_set
(
self
,
k
,
v
):
if
isinstance
(
v
,
dict
):
if
isinstance
(
v
,
dict
):
...
@@ -138,8 +154,8 @@ class ParamsDict(object):
...
@@ -138,8 +154,8 @@ class ParamsDict(object):
ValueError: if the ParamsDict instance has been locked.
ValueError: if the ParamsDict instance has been locked.
"""
"""
if
k
in
ParamsDict
.
RESERVED_ATTR
:
if
k
in
ParamsDict
.
RESERVED_ATTR
:
raise
AttributeError
(
'The key `{}` is reserved. No change is allowes. '
raise
AttributeError
(
.
format
(
k
))
'The key `{}` is reserved. No change is allowes. '
.
format
(
k
))
if
k
not
in
self
.
__dict__
.
keys
():
if
k
not
in
self
.
__dict__
.
keys
():
raise
AttributeError
(
'The key `{}` does not exist. '
.
format
(
k
))
raise
AttributeError
(
'The key `{}` does not exist. '
.
format
(
k
))
if
self
.
_locked
:
if
self
.
_locked
:
...
@@ -150,13 +166,13 @@ class ParamsDict(object):
...
@@ -150,13 +166,13 @@ class ParamsDict(object):
"""Override the ParamsDict with a set of given params.
"""Override the ParamsDict with a set of given params.
Args:
Args:
override_params: a dict or a ParamsDict specifying the parameters to
override_params: a dict or a ParamsDict specifying the parameters to
be
be
overridden.
overridden.
is_strict: a boolean specifying whether override is strict or not. If
is_strict: a boolean specifying whether override is strict or not. If
True, keys in `override_params` must be present in the ParamsDict.
True, keys in `override_params` must be present in the ParamsDict.
If
If
False, keys in `override_params` can be different from what is
False, keys in `override_params` can be different from what is
currently
currently
defined in the ParamsDict. In this case, the ParamsDict will
defined in the ParamsDict. In this case, the ParamsDict will
be extended
be extended
to include the new keys.
to include the new keys.
"""
"""
if
self
.
_locked
:
if
self
.
_locked
:
raise
ValueError
(
'The ParamsDict has been locked. No change is allowed.'
)
raise
ValueError
(
'The ParamsDict has been locked. No change is allowed.'
)
...
@@ -230,7 +246,7 @@ class ParamsDict(object):
...
@@ -230,7 +246,7 @@ class ParamsDict(object):
['a.a1 == b.ccc.a1', 'a.a2 <= b.bb.bb2']
['a.a1 == b.ccc.a1', 'a.a2 <= b.bb.bb2']
What it enforces are:
What it enforces are:
- a.a1 = 1 == b.ccc.a1 =
2
- a.a1 = 1 == b.ccc.a1 =
1
- a.a2 = 2 <= b.bb.bb2 = 20
- a.a2 = 2 <= b.bb.bb2 = 20
Raises:
Raises:
...
@@ -240,6 +256,7 @@ class ParamsDict(object):
...
@@ -240,6 +256,7 @@ class ParamsDict(object):
(2) any inconsistency violating the restriction is found.
(2) any inconsistency violating the restriction is found.
ValueError: if the restriction defined in the string is not supported.
ValueError: if the restriction defined in the string is not supported.
"""
"""
def
_get_kv
(
dotted_string
,
params_dict
):
def
_get_kv
(
dotted_string
,
params_dict
):
"""Get keys and values indicated by dotted_string."""
"""Get keys and values indicated by dotted_string."""
if
_CONST_VALUE_RE
.
match
(
dotted_string
)
is
not
None
:
if
_CONST_VALUE_RE
.
match
(
dotted_string
)
is
not
None
:
...
@@ -270,56 +287,64 @@ class ParamsDict(object):
...
@@ -270,56 +287,64 @@ class ParamsDict(object):
tokens
=
restriction
.
split
(
'=='
)
tokens
=
restriction
.
split
(
'=='
)
_
,
left_v
,
_
,
right_v
=
_get_kvs
(
tokens
,
params_dict
)
_
,
left_v
,
_
,
right_v
=
_get_kvs
(
tokens
,
params_dict
)
if
left_v
!=
right_v
:
if
left_v
!=
right_v
:
raise
KeyError
(
'Found inconsistncy between key `{}` and key `{}`.'
raise
KeyError
(
.
format
(
tokens
[
0
],
tokens
[
1
]))
'Found inconsistncy between key `{}` and key `{}`.'
.
format
(
tokens
[
0
],
tokens
[
1
]))
elif
'!='
in
restriction
:
elif
'!='
in
restriction
:
tokens
=
restriction
.
split
(
'!='
)
tokens
=
restriction
.
split
(
'!='
)
_
,
left_v
,
_
,
right_v
=
_get_kvs
(
tokens
,
params_dict
)
_
,
left_v
,
_
,
right_v
=
_get_kvs
(
tokens
,
params_dict
)
if
left_v
==
right_v
:
if
left_v
==
right_v
:
raise
KeyError
(
'Found inconsistncy between key `{}` and key `{}`.'
raise
KeyError
(
.
format
(
tokens
[
0
],
tokens
[
1
]))
'Found inconsistncy between key `{}` and key `{}`.'
.
format
(
tokens
[
0
],
tokens
[
1
]))
elif
'<'
in
restriction
:
elif
'<'
in
restriction
:
tokens
=
restriction
.
split
(
'<'
)
tokens
=
restriction
.
split
(
'<'
)
_
,
left_v
,
_
,
right_v
=
_get_kvs
(
tokens
,
params_dict
)
_
,
left_v
,
_
,
right_v
=
_get_kvs
(
tokens
,
params_dict
)
if
left_v
>=
right_v
:
if
left_v
>=
right_v
:
raise
KeyError
(
'Found inconsistncy between key `{}` and key `{}`.'
raise
KeyError
(
.
format
(
tokens
[
0
],
tokens
[
1
]))
'Found inconsistncy between key `{}` and key `{}`.'
.
format
(
tokens
[
0
],
tokens
[
1
]))
elif
'<='
in
restriction
:
elif
'<='
in
restriction
:
tokens
=
restriction
.
split
(
'<='
)
tokens
=
restriction
.
split
(
'<='
)
_
,
left_v
,
_
,
right_v
=
_get_kvs
(
tokens
,
params_dict
)
_
,
left_v
,
_
,
right_v
=
_get_kvs
(
tokens
,
params_dict
)
if
left_v
>
right_v
:
if
left_v
>
right_v
:
raise
KeyError
(
'Found inconsistncy between key `{}` and key `{}`.'
raise
KeyError
(
.
format
(
tokens
[
0
],
tokens
[
1
]))
'Found inconsistncy between key `{}` and key `{}`.'
.
format
(
tokens
[
0
],
tokens
[
1
]))
elif
'>'
in
restriction
:
elif
'>'
in
restriction
:
tokens
=
restriction
.
split
(
'>'
)
tokens
=
restriction
.
split
(
'>'
)
_
,
left_v
,
_
,
right_v
=
_get_kvs
(
tokens
,
params_dict
)
_
,
left_v
,
_
,
right_v
=
_get_kvs
(
tokens
,
params_dict
)
if
left_v
<=
right_v
:
if
left_v
<=
right_v
:
raise
KeyError
(
'Found inconsistncy between key `{}` and key `{}`.'
raise
KeyError
(
.
format
(
tokens
[
0
],
tokens
[
1
]))
'Found inconsistncy between key `{}` and key `{}`.'
.
format
(
tokens
[
0
],
tokens
[
1
]))
elif
'>='
in
restriction
:
elif
'>='
in
restriction
:
tokens
=
restriction
.
split
(
'>='
)
tokens
=
restriction
.
split
(
'>='
)
_
,
left_v
,
_
,
right_v
=
_get_kvs
(
tokens
,
params_dict
)
_
,
left_v
,
_
,
right_v
=
_get_kvs
(
tokens
,
params_dict
)
if
left_v
<
right_v
:
if
left_v
<
right_v
:
raise
KeyError
(
'Found inconsistncy between key `{}` and key `{}`.'
raise
KeyError
(
.
format
(
tokens
[
0
],
tokens
[
1
]))
'Found inconsistncy between key `{}` and key `{}`.'
.
format
(
tokens
[
0
],
tokens
[
1
]))
else
:
else
:
raise
ValueError
(
'Unsupported relation in restriction.'
)
raise
ValueError
(
'Unsupported relation in restriction.'
)
def
read_yaml_to_params_dict
(
file_path
):
def
read_yaml_to_params_dict
(
file_path
:
str
):
"""Reads a YAML file to a ParamsDict."""
"""Reads a YAML file to a ParamsDict."""
with
tf
.
io
.
gfile
.
GFile
(
file_path
,
'r'
)
as
f
:
with
tf
.
io
.
gfile
.
GFile
(
file_path
,
'r'
)
as
f
:
params_dict
=
yaml
.
load
(
f
)
params_dict
=
yaml
.
load
(
f
,
Loader
=
LOADER
)
return
ParamsDict
(
params_dict
)
return
ParamsDict
(
params_dict
)
def
save_params_dict_to_yaml
(
params
,
file_path
):
def
save_params_dict_to_yaml
(
params
,
file_path
):
"""Saves the input ParamsDict to a YAML file."""
"""Saves the input ParamsDict to a YAML file."""
with
tf
.
io
.
gfile
.
GFile
(
file_path
,
'w'
)
as
f
:
with
tf
.
io
.
gfile
.
GFile
(
file_path
,
'w'
)
as
f
:
def
_my_list_rep
(
dumper
,
data
):
def
_my_list_rep
(
dumper
,
data
):
# u'tag:yaml.org,2002:seq' is the YAML internal tag for sequence.
# u'tag:yaml.org,2002:seq' is the YAML internal tag for sequence.
return
dumper
.
represent_sequence
(
return
dumper
.
represent_sequence
(
u
'tag:yaml.org,2002:seq'
,
data
,
flow_style
=
True
)
u
'tag:yaml.org,2002:seq'
,
data
,
flow_style
=
True
)
yaml
.
add_representer
(
list
,
_my_list_rep
)
yaml
.
add_representer
(
list
,
_my_list_rep
)
yaml
.
dump
(
params
.
as_dict
(),
f
,
default_flow_style
=
False
)
yaml
.
dump
(
params
.
as_dict
(),
f
,
default_flow_style
=
False
)
...
@@ -408,8 +433,8 @@ def override_params_dict(params, dict_or_string_or_yaml_file, is_strict):
...
@@ -408,8 +433,8 @@ def override_params_dict(params, dict_or_string_or_yaml_file, is_strict):
Args:
Args:
params: a ParamsDict object to be overridden.
params: a ParamsDict object to be overridden.
dict_or_string_or_yaml_file: a Python dict, JSON/YAML/CSV string or
dict_or_string_or_yaml_file: a Python dict, JSON/YAML/CSV string or
path to
path to
a YAML file specifying the parameters to be overridden.
a YAML file specifying the parameters to be overridden.
is_strict: a boolean specifying whether override is strict or not.
is_strict: a boolean specifying whether override is strict or not.
Returns:
Returns:
...
@@ -428,12 +453,12 @@ def override_params_dict(params, dict_or_string_or_yaml_file, is_strict):
...
@@ -428,12 +453,12 @@ def override_params_dict(params, dict_or_string_or_yaml_file, is_strict):
nested_csv_str_to_json_str
(
dict_or_string_or_yaml_file
))
nested_csv_str_to_json_str
(
dict_or_string_or_yaml_file
))
except
ValueError
:
except
ValueError
:
pass
pass
params_dict
=
yaml
.
load
(
dict_or_string_or_yaml_file
)
params_dict
=
yaml
.
load
(
dict_or_string_or_yaml_file
,
Loader
=
LOADER
)
if
isinstance
(
params_dict
,
dict
):
if
isinstance
(
params_dict
,
dict
):
params
.
override
(
params_dict
,
is_strict
)
params
.
override
(
params_dict
,
is_strict
)
else
:
else
:
with
tf
.
io
.
gfile
.
GFile
(
dict_or_string_or_yaml_file
)
as
f
:
with
tf
.
io
.
gfile
.
GFile
(
dict_or_string_or_yaml_file
)
as
f
:
params
.
override
(
yaml
.
load
(
f
),
is_strict
)
params
.
override
(
yaml
.
load
(
f
,
Loader
=
yaml
.
FullLoader
),
is_strict
)
else
:
else
:
raise
ValueError
(
'Unknown input type to parse.'
)
raise
ValueError
(
'Unknown input type to parse.'
)
return
params
return
params
official/modeling/hyperparams/params_dict_test.py
View file @
f16a7b5b
# Copyright 201
9
The TensorFlow Authors. All Rights Reserved.
# Copyright 20
2
1 The TensorFlow Authors. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -11,7 +11,6 @@
...
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
"""Tests for params_dict.py."""
"""Tests for params_dict.py."""
...
@@ -56,8 +55,7 @@ class ParamsDictTest(tf.test.TestCase):
...
@@ -56,8 +55,7 @@ class ParamsDictTest(tf.test.TestCase):
def
test_setattr
(
self
):
def
test_setattr
(
self
):
params
=
params_dict
.
ParamsDict
()
params
=
params_dict
.
ParamsDict
()
params
.
override
(
params
.
override
({
'a'
:
'aa'
,
'b'
:
2
,
'c'
:
None
},
is_strict
=
False
)
{
'a'
:
'aa'
,
'b'
:
2
,
'c'
:
None
},
is_strict
=
False
)
params
.
c
=
'ccc'
params
.
c
=
'ccc'
self
.
assertEqual
(
params
.
a
,
'aa'
)
self
.
assertEqual
(
params
.
a
,
'aa'
)
self
.
assertEqual
(
params
.
b
,
2
)
self
.
assertEqual
(
params
.
b
,
2
)
...
@@ -65,17 +63,23 @@ class ParamsDictTest(tf.test.TestCase):
...
@@ -65,17 +63,23 @@ class ParamsDictTest(tf.test.TestCase):
def
test_getattr
(
self
):
def
test_getattr
(
self
):
params
=
params_dict
.
ParamsDict
()
params
=
params_dict
.
ParamsDict
()
params
.
override
(
params
.
override
({
'a'
:
'aa'
,
'b'
:
2
,
'c'
:
None
},
is_strict
=
False
)
{
'a'
:
'aa'
,
'b'
:
2
,
'c'
:
None
},
is_strict
=
False
)
self
.
assertEqual
(
params
.
a
,
'aa'
)
self
.
assertEqual
(
params
.
a
,
'aa'
)
self
.
assertEqual
(
params
.
b
,
2
)
self
.
assertEqual
(
params
.
b
,
2
)
self
.
assertEqual
(
params
.
c
,
None
)
self
.
assertEqual
(
params
.
c
,
None
)
def
test_delattr
(
self
):
def
test_delattr
(
self
):
params
=
params_dict
.
ParamsDict
()
params
=
params_dict
.
ParamsDict
()
params
.
override
(
params
.
override
({
{
'a'
:
'aa'
,
'b'
:
2
,
'c'
:
None
,
'd'
:
{
'd1'
:
1
,
'd2'
:
10
}},
'a'
:
'aa'
,
is_strict
=
False
)
'b'
:
2
,
'c'
:
None
,
'd'
:
{
'd1'
:
1
,
'd2'
:
10
}
},
is_strict
=
False
)
del
params
.
c
del
params
.
c
self
.
assertEqual
(
params
.
a
,
'aa'
)
self
.
assertEqual
(
params
.
a
,
'aa'
)
self
.
assertEqual
(
params
.
b
,
2
)
self
.
assertEqual
(
params
.
b
,
2
)
...
@@ -87,22 +91,26 @@ class ParamsDictTest(tf.test.TestCase):
...
@@ -87,22 +91,26 @@ class ParamsDictTest(tf.test.TestCase):
def
test_contains
(
self
):
def
test_contains
(
self
):
params
=
params_dict
.
ParamsDict
()
params
=
params_dict
.
ParamsDict
()
params
.
override
(
params
.
override
({
'a'
:
'aa'
},
is_strict
=
False
)
{
'a'
:
'aa'
},
is_strict
=
False
)
self
.
assertIn
(
'a'
,
params
)
self
.
assertIn
(
'a'
,
params
)
self
.
assertNotIn
(
'b'
,
params
)
self
.
assertNotIn
(
'b'
,
params
)
def
test_get
(
self
):
def
test_get
(
self
):
params
=
params_dict
.
ParamsDict
()
params
=
params_dict
.
ParamsDict
()
params
.
override
(
params
.
override
({
'a'
:
'aa'
},
is_strict
=
False
)
{
'a'
:
'aa'
},
is_strict
=
False
)
self
.
assertEqual
(
params
.
get
(
'a'
),
'aa'
)
self
.
assertEqual
(
params
.
get
(
'a'
),
'aa'
)
self
.
assertEqual
(
params
.
get
(
'b'
,
2
),
2
)
self
.
assertEqual
(
params
.
get
(
'b'
,
2
),
2
)
self
.
assertEqual
(
params
.
get
(
'b'
),
None
)
self
.
assertEqual
(
params
.
get
(
'b'
),
None
)
def
test_override_is_strict_true
(
self
):
def
test_override_is_strict_true
(
self
):
params
=
params_dict
.
ParamsDict
(
params
=
params_dict
.
ParamsDict
({
{
'a'
:
'aa'
,
'b'
:
2
,
'c'
:
{
'c1'
:
'cc'
,
'c2'
:
20
}})
'a'
:
'aa'
,
'b'
:
2
,
'c'
:
{
'c1'
:
'cc'
,
'c2'
:
20
}
})
params
.
override
({
'a'
:
2
,
'c'
:
{
'c1'
:
'ccc'
}},
is_strict
=
True
)
params
.
override
({
'a'
:
2
,
'c'
:
{
'c1'
:
'ccc'
}},
is_strict
=
True
)
self
.
assertEqual
(
params
.
a
,
2
)
self
.
assertEqual
(
params
.
a
,
2
)
self
.
assertEqual
(
params
.
c
.
c1
,
'ccc'
)
self
.
assertEqual
(
params
.
c
.
c1
,
'ccc'
)
...
@@ -112,8 +120,14 @@ class ParamsDictTest(tf.test.TestCase):
...
@@ -112,8 +120,14 @@ class ParamsDictTest(tf.test.TestCase):
params
.
override
({
'c'
:
{
'c3'
:
30
}},
is_strict
=
True
)
params
.
override
({
'c'
:
{
'c3'
:
30
}},
is_strict
=
True
)
def
test_override_is_strict_false
(
self
):
def
test_override_is_strict_false
(
self
):
params
=
params_dict
.
ParamsDict
(
params
=
params_dict
.
ParamsDict
({
{
'a'
:
'aa'
,
'b'
:
2
,
'c'
:
{
'c1'
:
10
,
'c2'
:
20
}})
'a'
:
'aa'
,
'b'
:
2
,
'c'
:
{
'c1'
:
10
,
'c2'
:
20
}
})
params
.
override
({
'a'
:
2
,
'c'
:
{
'c3'
:
3000
}},
is_strict
=
False
)
params
.
override
({
'a'
:
2
,
'c'
:
{
'c3'
:
3000
}},
is_strict
=
False
)
self
.
assertEqual
(
params
.
a
,
2
)
self
.
assertEqual
(
params
.
a
,
2
)
self
.
assertEqual
(
params
.
c
.
c3
,
3000
)
self
.
assertEqual
(
params
.
c
.
c3
,
3000
)
...
@@ -123,8 +137,14 @@ class ParamsDictTest(tf.test.TestCase):
...
@@ -123,8 +137,14 @@ class ParamsDictTest(tf.test.TestCase):
self
.
assertEqual
(
params
.
c
.
c4
,
4444
)
self
.
assertEqual
(
params
.
c
.
c4
,
4444
)
def
test_as_dict
(
self
):
def
test_as_dict
(
self
):
params
=
params_dict
.
ParamsDict
(
params
=
params_dict
.
ParamsDict
({
{
'a'
:
'aa'
,
'b'
:
2
,
'c'
:
{
'c1'
:
10
,
'c2'
:
20
}})
'a'
:
'aa'
,
'b'
:
2
,
'c'
:
{
'c1'
:
10
,
'c2'
:
20
}
})
params_d
=
params
.
as_dict
()
params_d
=
params
.
as_dict
()
self
.
assertEqual
(
params_d
[
'a'
],
'aa'
)
self
.
assertEqual
(
params_d
[
'a'
],
'aa'
)
self
.
assertEqual
(
params_d
[
'b'
],
2
)
self
.
assertEqual
(
params_d
[
'b'
],
2
)
...
@@ -134,21 +154,27 @@ class ParamsDictTest(tf.test.TestCase):
...
@@ -134,21 +154,27 @@ class ParamsDictTest(tf.test.TestCase):
def
test_validate
(
self
):
def
test_validate
(
self
):
# Raise error due to the unknown parameter.
# Raise error due to the unknown parameter.
with
self
.
assertRaises
(
KeyError
):
with
self
.
assertRaises
(
KeyError
):
params
=
params_dict
.
ParamsDict
(
params
=
params_dict
.
ParamsDict
(
{
'a'
:
1
,
'b'
:
{
'a'
:
11
}},
[
'a == c'
])
{
'a'
:
1
,
'b'
:
{
'a'
:
11
}},
[
'a == c'
]
)
params
.
validate
(
)
# OK to check equality of two nested dicts.
# OK to check equality of two nested dicts.
params
=
params_dict
.
ParamsDict
(
params
=
params_dict
.
ParamsDict
({
{
'a'
:
1
,
'b'
:
{
'a'
:
10
},
'c'
:
{
'a'
:
10
}},
[
'b == c'
])
'a'
:
1
,
'b'
:
{
'a'
:
10
},
'c'
:
{
'a'
:
10
}
},
[
'b == c'
])
# Raise error due to inconsistency
# Raise error due to inconsistency
with
self
.
assertRaises
(
KeyError
):
with
self
.
assertRaises
(
KeyError
):
params
=
params_dict
.
ParamsDict
(
params
=
params_dict
.
ParamsDict
(
{
'a'
:
1
,
'c'
:
{
'a'
:
10
}},
[
'a == c.a'
])
{
'a'
:
1
,
'c'
:
{
'a'
:
10
}},
[
'a == c.a'
]
)
params
.
validate
(
)
# Valid rule.
# Valid rule.
params
=
params_dict
.
ParamsDict
(
params
=
params_dict
.
ParamsDict
({
'a'
:
1
,
'c'
:
{
'a'
:
1
}},
[
'a == c.a'
])
{
'a'
:
1
,
'c'
:
{
'a'
:
1
}},
[
'a == c.a'
])
# Overridding violates the existing rule, raise error upon validate.
# Overridding violates the existing rule, raise error upon validate.
params
.
override
({
'a'
:
11
})
params
.
override
({
'a'
:
11
})
...
@@ -156,12 +182,21 @@ class ParamsDictTest(tf.test.TestCase):
...
@@ -156,12 +182,21 @@ class ParamsDictTest(tf.test.TestCase):
params
.
validate
()
params
.
validate
()
# Valid restrictions with constant.
# Valid restrictions with constant.
params
=
params_dict
.
ParamsDict
(
params
=
params_dict
.
ParamsDict
({
{
'a'
:
None
,
'c'
:
{
'a'
:
1
}},
[
'a == None'
,
'c.a == 1'
])
'a'
:
None
,
'c'
:
{
'a'
:
1
}
},
[
'a == None'
,
'c.a == 1'
])
params
.
validate
()
params
.
validate
()
with
self
.
assertRaises
(
KeyError
):
with
self
.
assertRaises
(
KeyError
):
params
=
params_dict
.
ParamsDict
(
params
=
params_dict
.
ParamsDict
({
{
'a'
:
4
,
'c'
:
{
'a'
:
1
}},
[
'a == None'
,
'c.a == 1'
])
'a'
:
4
,
'c'
:
{
'a'
:
1
}
},
[
'a == None'
,
'c.a == 1'
])
params
.
validate
()
class
ParamsDictIOTest
(
tf
.
test
.
TestCase
):
class
ParamsDictIOTest
(
tf
.
test
.
TestCase
):
...
@@ -173,8 +208,14 @@ class ParamsDictIOTest(tf.test.TestCase):
...
@@ -173,8 +208,14 @@ class ParamsDictIOTest(tf.test.TestCase):
return
temp_file
return
temp_file
def
test_save_params_dict_to_yaml
(
self
):
def
test_save_params_dict_to_yaml
(
self
):
params
=
params_dict
.
ParamsDict
(
params
=
params_dict
.
ParamsDict
({
{
'a'
:
'aa'
,
'b'
:
2
,
'c'
:
{
'c1'
:
10
,
'c2'
:
20
}})
'a'
:
'aa'
,
'b'
:
2
,
'c'
:
{
'c1'
:
10
,
'c2'
:
20
}
})
output_yaml_file
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
'params.yaml'
)
output_yaml_file
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
'params.yaml'
)
params_dict
.
save_params_dict_to_yaml
(
params
,
output_yaml_file
)
params_dict
.
save_params_dict_to_yaml
(
params
,
output_yaml_file
)
...
@@ -203,7 +244,12 @@ class ParamsDictIOTest(tf.test.TestCase):
...
@@ -203,7 +244,12 @@ class ParamsDictIOTest(tf.test.TestCase):
def
test_override_params_dict_using_dict
(
self
):
def
test_override_params_dict_using_dict
(
self
):
params
=
params_dict
.
ParamsDict
({
params
=
params_dict
.
ParamsDict
({
'a'
:
1
,
'b'
:
2.5
,
'c'
:
[
3
,
4
],
'd'
:
'hello'
,
'e'
:
False
})
'a'
:
1
,
'b'
:
2.5
,
'c'
:
[
3
,
4
],
'd'
:
'hello'
,
'e'
:
False
})
override_dict
=
{
'b'
:
5.2
,
'c'
:
[
30
,
40
]}
override_dict
=
{
'b'
:
5.2
,
'c'
:
[
30
,
40
]}
params
=
params_dict
.
override_params_dict
(
params
=
params_dict
.
override_params_dict
(
params
,
override_dict
,
is_strict
=
True
)
params
,
override_dict
,
is_strict
=
True
)
...
@@ -215,7 +261,12 @@ class ParamsDictIOTest(tf.test.TestCase):
...
@@ -215,7 +261,12 @@ class ParamsDictIOTest(tf.test.TestCase):
def
test_override_params_dict_using_yaml_string
(
self
):
def
test_override_params_dict_using_yaml_string
(
self
):
params
=
params_dict
.
ParamsDict
({
params
=
params_dict
.
ParamsDict
({
'a'
:
1
,
'b'
:
2.5
,
'c'
:
[
3
,
4
],
'd'
:
'hello'
,
'e'
:
False
})
'a'
:
1
,
'b'
:
2.5
,
'c'
:
[
3
,
4
],
'd'
:
'hello'
,
'e'
:
False
})
override_yaml_string
=
"'b': 5.2
\n
'c': [30, 40]"
override_yaml_string
=
"'b': 5.2
\n
'c': [30, 40]"
params
=
params_dict
.
override_params_dict
(
params
=
params_dict
.
override_params_dict
(
params
,
override_yaml_string
,
is_strict
=
True
)
params
,
override_yaml_string
,
is_strict
=
True
)
...
@@ -227,8 +278,18 @@ class ParamsDictIOTest(tf.test.TestCase):
...
@@ -227,8 +278,18 @@ class ParamsDictIOTest(tf.test.TestCase):
def
test_override_params_dict_using_json_string
(
self
):
def
test_override_params_dict_using_json_string
(
self
):
params
=
params_dict
.
ParamsDict
({
params
=
params_dict
.
ParamsDict
({
'a'
:
1
,
'b'
:
{
'b1'
:
2
,
'b2'
:
[
2
,
3
],},
'a'
:
1
,
'd'
:
{
'd1'
:
{
'd2'
:
'hello'
}},
'e'
:
False
})
'b'
:
{
'b1'
:
2
,
'b2'
:
[
2
,
3
],
},
'd'
:
{
'd1'
:
{
'd2'
:
'hello'
}
},
'e'
:
False
})
override_json_string
=
"{ b: { b2: [3, 4] }, d: { d1: { d2: 'hi' } } }"
override_json_string
=
"{ b: { b2: [3, 4] }, d: { d1: { d2: 'hi' } } }"
params
=
params_dict
.
override_params_dict
(
params
=
params_dict
.
override_params_dict
(
params
,
override_json_string
,
is_strict
=
True
)
params
,
override_json_string
,
is_strict
=
True
)
...
@@ -240,8 +301,18 @@ class ParamsDictIOTest(tf.test.TestCase):
...
@@ -240,8 +301,18 @@ class ParamsDictIOTest(tf.test.TestCase):
def
test_override_params_dict_using_csv_string
(
self
):
def
test_override_params_dict_using_csv_string
(
self
):
params
=
params_dict
.
ParamsDict
({
params
=
params_dict
.
ParamsDict
({
'a'
:
1
,
'b'
:
{
'b1'
:
2
,
'b2'
:
[
2
,
3
],},
'a'
:
1
,
'd'
:
{
'd1'
:
{
'd2'
:
'hello'
}},
'e'
:
False
})
'b'
:
{
'b1'
:
2
,
'b2'
:
[
2
,
3
],
},
'd'
:
{
'd1'
:
{
'd2'
:
'hello'
}
},
'e'
:
False
})
override_csv_string
=
"b.b2=[3,4], d.d1.d2='hi, world', e=gs://test"
override_csv_string
=
"b.b2=[3,4], d.d1.d2='hi, world', e=gs://test"
params
=
params_dict
.
override_params_dict
(
params
=
params_dict
.
override_params_dict
(
params
,
override_csv_string
,
is_strict
=
True
)
params
,
override_csv_string
,
is_strict
=
True
)
...
@@ -250,10 +321,23 @@ class ParamsDictIOTest(tf.test.TestCase):
...
@@ -250,10 +321,23 @@ class ParamsDictIOTest(tf.test.TestCase):
self
.
assertEqual
([
3
,
4
],
params
.
b
.
b2
)
self
.
assertEqual
([
3
,
4
],
params
.
b
.
b2
)
self
.
assertEqual
(
'hi, world'
,
params
.
d
.
d1
.
d2
)
self
.
assertEqual
(
'hi, world'
,
params
.
d
.
d1
.
d2
)
self
.
assertEqual
(
'gs://test'
,
params
.
e
)
self
.
assertEqual
(
'gs://test'
,
params
.
e
)
# Test different float formats
override_csv_string
=
'b.b2=-1.e-3, d.d1.d2=+0.001, e=1e+3, a=-1.5E-3'
params
=
params_dict
.
override_params_dict
(
params
,
override_csv_string
,
is_strict
=
True
)
self
.
assertEqual
(
-
1e-3
,
params
.
b
.
b2
)
self
.
assertEqual
(
0.001
,
params
.
d
.
d1
.
d2
)
self
.
assertEqual
(
1e3
,
params
.
e
)
self
.
assertEqual
(
-
1.5e-3
,
params
.
a
)
def
test_override_params_dict_using_yaml_file
(
self
):
def
test_override_params_dict_using_yaml_file
(
self
):
params
=
params_dict
.
ParamsDict
({
params
=
params_dict
.
ParamsDict
({
'a'
:
1
,
'b'
:
2.5
,
'c'
:
[
3
,
4
],
'd'
:
'hello'
,
'e'
:
False
})
'a'
:
1
,
'b'
:
2.5
,
'c'
:
[
3
,
4
],
'd'
:
'hello'
,
'e'
:
False
})
override_yaml_file
=
self
.
write_temp_file
(
override_yaml_file
=
self
.
write_temp_file
(
'params.yaml'
,
r
"""
'params.yaml'
,
r
"""
b: 5.2
b: 5.2
...
@@ -321,8 +405,7 @@ class IOTest(tf.test.TestCase):
...
@@ -321,8 +405,7 @@ class IOTest(tf.test.TestCase):
def
test_csv_str_load_unsupported_datatypes
(
self
):
def
test_csv_str_load_unsupported_datatypes
(
self
):
csv_str
=
'a=[[1,2,3],[4,5,6]]'
csv_str
=
'a=[[1,2,3],[4,5,6]]'
self
.
assertRaises
(
ValueError
,
self
.
assertRaises
(
ValueError
,
params_dict
.
nested_csv_str_to_json_str
,
params_dict
.
nested_csv_str_to_json_str
,
csv_str
)
csv_str
)
def
test_csv_str_to_json_str_spacing
(
self
):
def
test_csv_str_to_json_str_spacing
(
self
):
...
...
official/modeling/multitask/__init__.py
0 → 100644
View file @
f16a7b5b
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
official/modeling/multitask/base_model.py
0 → 100644
View file @
f16a7b5b
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Abstraction of multi-task model."""
from
typing
import
Text
,
Dict
import
tensorflow
as
tf
class
MultiTaskBaseModel
(
tf
.
Module
):
"""Base class that holds multi-task model computation."""
def
__init__
(
self
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
self
.
_sub_tasks
=
self
.
_instantiate_sub_tasks
()
def
_instantiate_sub_tasks
(
self
)
->
Dict
[
Text
,
tf
.
keras
.
Model
]:
"""Abstract function that sets up the computation for each sub-task.
Returns:
A map from task name (as string) to a tf.keras.Model object that
represents the sub-task in the multi-task pool.
"""
raise
NotImplementedError
(
"_instantiate_sub_task_models() is not implemented."
)
@
property
def
sub_tasks
(
self
):
"""Fetch a map of task name (string) to task model (tf.keras.Model)."""
return
self
.
_sub_tasks
def
initialize
(
self
):
"""Optional function that loads a pre-train checkpoint."""
return
official/modeling/multitask/base_trainer.py
0 → 100644
View file @
f16a7b5b
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Multitask base trainer implementation.
The trainer derives from the Orbit `StandardTrainer` class.
"""
from
typing
import
Union
import
gin
import
orbit
import
tensorflow
as
tf
from
official.modeling.multitask
import
base_model
from
official.modeling.multitask
import
multitask
@
gin
.
configurable
class
MultiTaskBaseTrainer
(
orbit
.
StandardTrainer
):
"""Multitask base trainer."""
def
__init__
(
self
,
multi_task
:
multitask
.
MultiTask
,
multi_task_model
:
Union
[
tf
.
keras
.
Model
,
base_model
.
MultiTaskBaseModel
],
optimizer
:
tf
.
optimizers
.
Optimizer
,
trainer_options
=
None
):
self
.
_strategy
=
tf
.
distribute
.
get_strategy
()
self
.
_multi_task
=
multi_task
self
.
_multi_task_model
=
multi_task_model
self
.
_optimizer
=
optimizer
self
.
_training_losses
=
None
self
.
_training_metrics
=
None
self
.
_global_step
=
orbit
.
utils
.
create_global_step
()
if
hasattr
(
self
.
multi_task_model
,
"checkpoint_items"
):
checkpoint_items
=
self
.
multi_task_model
.
checkpoint_items
else
:
checkpoint_items
=
{}
self
.
_checkpoint
=
tf
.
train
.
Checkpoint
(
model
=
self
.
multi_task_model
,
optimizer
=
self
.
optimizer
,
global_step
=
self
.
global_step
,
**
checkpoint_items
)
train_datasets
=
{}
for
name
,
task
in
self
.
multi_task
.
tasks
.
items
():
train_datasets
[
name
]
=
orbit
.
utils
.
make_distributed_dataset
(
self
.
strategy
,
task
.
build_inputs
,
task
.
task_config
.
train_data
)
super
().
__init__
(
train_dataset
=
train_datasets
,
options
=
trainer_options
or
orbit
.
StandardTrainerOptions
())
def
train_loop_begin
(
self
):
"""Clean up states that hold losses and metrics."""
for
_
,
train_loss_metric
in
self
.
training_losses
.
items
():
train_loss_metric
.
reset_states
()
for
_
,
metrics
in
self
.
training_metrics
.
items
():
for
metric
in
metrics
:
metric
.
reset_states
()
def
train_loop_end
(
self
):
"""Record loss and metric values per task."""
result
=
{}
for
task_name
,
loss
in
self
.
training_losses
.
items
():
result
[
task_name
]
=
{
loss
.
name
:
loss
.
result
()}
for
task_name
,
task_metrics
in
self
.
training_metrics
.
items
():
result
[
task_name
].
update
(
{
metric
.
name
:
metric
.
result
()
for
metric
in
task_metrics
})
# Note that, the learning rate schedule is managed by the keras optimizer
# internally, which respects the number of backward pass as `iterations`.
# The learning rate schedule does not follow the trainer logical global
# step of multiple tasks.
if
callable
(
self
.
optimizer
.
learning_rate
):
result
[
"learning_rate"
]
=
self
.
optimizer
.
learning_rate
(
self
.
optimizer
.
iterations
)
else
:
result
[
"learning_rate"
]
=
self
.
optimizer
.
learning_rate
return
result
@property
def checkpoint(self):
  """The `tf.train.Checkpoint` built for this trainer's model and optimizer."""
  return self._checkpoint
@property
def training_losses(self):
  """Lazily builds and returns the training loss metrics of all tasks.

  The dictionary contains one mean-loss metric per task plus a
  "total_loss" entry that accumulates the weighted sum computed by the
  joint training step.
  """
  if self._training_losses is None:
    losses = {
        "total_loss":
            tf.keras.metrics.Mean("training_loss", dtype=tf.float32)
    }
    for task_name in self.multi_task.tasks:
      losses[task_name] = tf.keras.metrics.Mean(
          "training_loss", dtype=tf.float32)
    self._training_losses = losses
  return self._training_losses
@property
def training_metrics(self):
  """Lazily builds and returns the training metrics for every task."""
  if self._training_metrics is None:
    self._training_metrics = {
        name: task.build_metrics(training=True)
        for name, task in self.multi_task.tasks.items()
    }
  return self._training_metrics
@property
def strategy(self):
  """Returns the stored distribution strategy."""
  return self._strategy

@property
def multi_task(self):
  """Returns the `MultiTask` instance managing the individual tasks."""
  return self._multi_task

@property
def multi_task_model(self):
  """Returns the model shared across the tasks."""
  return self._multi_task_model

@property
def optimizer(self):
  """Returns the optimizer driving the joint updates."""
  return self._optimizer

@property
def global_step(self):
  """Returns the trainer's logical global step variable."""
  return self._global_step
def train_step(self, iterator_map):
  """Runs one joint train step over all tasks.

  Args:
    iterator_map: a dictionary of task names and per-task dataset iterators.
  """

  def _replica_step(task_inputs):
    # Delegate the forward/backward pass to the MultiTask object and fold
    # the returned per-task losses into the trainer's loss metrics.
    step_losses = self.multi_task.joint_train_step(
        task_inputs,
        multi_task_model=self.multi_task_model,
        optimizer=self.optimizer,
        task_metrics=self.training_metrics)
    for task_name, loss_value in step_losses.items():
      self.training_losses[task_name].update_state(loss_value)

  # Draw one batch from every task's iterator and run the step per replica.
  per_task_batch = tf.nest.map_structure(next, iterator_map)
  self.strategy.run(_replica_step, args=(per_task_batch,))
  self.global_step.assign_add(1)
official/modeling/multitask/base_trainer_test.py
0 → 100644
View file @
f16a7b5b
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask.base_trainer."""
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
tensorflow.python.distribute
import
combinations
from
tensorflow.python.distribute
import
strategy_combinations
from
official.modeling.multitask
import
base_trainer
from
official.modeling.multitask
import
configs
from
official.modeling.multitask
import
multitask
from
official.modeling.multitask
import
test_utils
def all_strategy_combinations():
  """Returns eager-mode test combinations over the supported strategies."""
  strategies = [
      strategy_combinations.default_strategy,
      strategy_combinations.cloud_tpu_strategy,
      strategy_combinations.one_device_strategy_gpu,
  ]
  return combinations.combine(distribution=strategies, mode="eager")
class BaseTrainerTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for the joint MultiTaskBaseTrainer."""

  @combinations.generate(all_strategy_combinations())
  def test_multitask_joint_trainer(self, distribution):
    with distribution.scope():
      mock_tasks = [
          test_utils.MockFooTask(params=test_utils.FooConfig(), name="foo"),
          test_utils.MockBarTask(params=test_utils.BarConfig(), name="bar")
      ]
      weights = {"foo": 1.0, "bar": 1.0}
      joint_task = multitask.MultiTask(tasks=mock_tasks, task_weights=weights)
      sgd = tf.keras.optimizers.SGD(0.1)
      joint_model = test_utils.MockMultiTaskModel()
      trainer = base_trainer.MultiTaskBaseTrainer(
          multi_task=joint_task,
          multi_task_model=joint_model,
          optimizer=sgd)
      logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
      # Each task should report its loss plus its own accuracy metric.
      self.assertContainsSubset(["training_loss", "bar_acc"],
                                logs["bar"].keys())
      self.assertContainsSubset(["training_loss", "foo_acc"],
                                logs["foo"].keys())

  def test_trainer_with_configs(self):
    config = configs.MultiTaskConfig(
        task_routines=(
            configs.TaskRoutine(
                task_name="foo",
                task_config=test_utils.FooConfig(),
                task_weight=0.5),
            configs.TaskRoutine(
                task_name="bar",
                task_config=test_utils.BarConfig(),
                task_weight=0.5)))
    joint_task = multitask.MultiTask.from_config(config)
    sgd = tf.keras.optimizers.SGD(0.1)
    joint_model = test_utils.MockMultiTaskModel()
    trainer = base_trainer.MultiTaskBaseTrainer(
        multi_task=joint_task,
        multi_task_model=joint_model,
        optimizer=sgd)
    logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
    self.assertContainsSubset(["training_loss", "bar_acc"],
                              logs["bar"].keys())
    self.assertContainsSubset(["training_loss", "foo_acc"],
                              logs["foo"].keys())
    # Config-provided weights and the trainer's step accounting must survive.
    self.assertEqual(joint_task.task_weight("foo"), 0.5)
    self.assertEqual(trainer.global_step.numpy(), 5)
    self.assertIn("learning_rate", logs)
# Run the test suite when this module is executed directly.
if __name__ == "__main__":
  tf.test.main()
official/modeling/multitask/configs.py
0 → 100644
View file @
f16a7b5b
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration definitions for multi-task training."""
from
typing
import
Optional
,
Tuple
import
dataclasses
from
official.core
import
config_definitions
as
cfg
from
official.modeling
import
hyperparams
@dataclasses.dataclass
class TaskRoutine(hyperparams.Config):
  """Configuration for one task participating in multi-task training."""
  # Unique name keying this task in datasets, metrics and samplers.
  task_name: str = ""
  # Per-task configuration; populated by the concrete experiment definition.
  task_config: cfg.TaskConfig = None
  # Number of eval steps for this task; when None, the evaluator falls back
  # to its caller-supplied default number of steps.
  eval_steps: Optional[int] = None
  # Relative weight used when summing task losses or sampling tasks.
  task_weight: Optional[float] = 1.0
@dataclasses.dataclass
class MultiTaskConfig(hyperparams.Config):
  """Top-level configuration holding all task routines of an experiment."""
  # Optional path to an initialization checkpoint — presumably consumed by
  # the experiment driver; not read in this module.
  init_checkpoint: str = ""
  # Shared model configuration; the concrete type is experiment-specific.
  model: hyperparams.Config = None
  # One TaskRoutine per participating task.
  task_routines: Tuple[TaskRoutine, ...] = ()
@dataclasses.dataclass
class ProportionalSampleConfig(hyperparams.Config):
  """Parameters for sampling tasks proportionally to their weights."""
  # Exponent applied to each task weight before normalization
  # (see task_sampler.ProportionalTaskSampler).
  alpha: float = 1.0
@dataclasses.dataclass
class AnnealingSampleConfig(hyperparams.Config):
  """Parameters for the annealing task sampling schedule."""
  # Number of train steps that make up one epoch of the schedule.
  steps_per_epoch: int = 5
  # Total number of train steps over which the distribution anneals.
  total_steps: int = 20
@dataclasses.dataclass
class TaskSamplingConfig(hyperparams.OneOfConfig):
  """One-of config selecting a task sampling strategy."""
  # Name of the selected oneof field: "uniform", "proportional" or "annealing".
  type: str = ""
  uniform: hyperparams.Config = hyperparams.Config()
  proportional: ProportionalSampleConfig = ProportionalSampleConfig()
  annealing: AnnealingSampleConfig = AnnealingSampleConfig()
@dataclasses.dataclass
class MultiTaskTrainerConfig(cfg.TrainerConfig):
  """Trainer configuration for multi-task training."""
  # Which trainer implementation to use, e.g. "interleaving".
  trainer_type: str = "interleaving"
  # How the interleaving trainer samples the next task to update.
  task_sampler: TaskSamplingConfig = TaskSamplingConfig(type="proportional")
@dataclasses.dataclass
class MultiTaskExperimentConfig(hyperparams.Config):
  """An experiment config for multi-task training and multi-task evaluation."""
  # The set of tasks (and their routines) trained together.
  task: MultiTaskConfig = MultiTaskConfig()
  # Trainer settings, including the task sampling strategy.
  trainer: MultiTaskTrainerConfig = MultiTaskTrainerConfig()
  # Runtime (hardware/distribution) settings.
  runtime: cfg.RuntimeConfig = cfg.RuntimeConfig()
@dataclasses.dataclass
class MultiEvalExperimentConfig(cfg.ExperimentConfig):
  """An experiment config for single-task training and multi-task evaluation.

  Attributes:
    eval_tasks: individual evaluation tasks.
  """
  eval_tasks: MultiTaskConfig = MultiTaskConfig()
official/modeling/multitask/evaluator.py
0 → 100644
View file @
f16a7b5b
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multitask Evaluator implementation.
The evaluator implements the Orbit `AbstractEvaluator` interface.
"""
from
typing
import
Optional
,
Union
import
gin
import
orbit
import
tensorflow
as
tf
from
official.core
import
train_utils
from
official.modeling.multitask
import
base_model
from
official.modeling.multitask
import
multitask
@gin.configurable
class MultiTaskEvaluator(orbit.AbstractEvaluator):
  """Implements the common trainer shared for TensorFlow models."""

  def __init__(self,
               task: multitask.MultiTask,
               model: Union[tf.keras.Model, base_model.MultiTaskBaseModel],
               global_step: Optional[tf.Variable] = None,
               checkpoint_exporter: Optional[
                   train_utils.BestCheckpointExporter] = None):
    """Initialize common trainer for TensorFlow models.

    Args:
      task: A multitask.MultiTask instance.
      model: tf.keras.Model instance.
      global_step: the global step variable.
      checkpoint_exporter: an object that has the `maybe_export_checkpoint`
        interface.
    """
    # Gets the current distribution strategy. If not inside any strategy scope,
    # it gets a single-replica no-op strategy.
    self._strategy = tf.distribute.get_strategy()
    self._task = task
    self._model = model
    self._global_step = global_step or orbit.utils.create_global_step()
    self._checkpoint_exporter = checkpoint_exporter
    self._checkpoint = tf.train.Checkpoint(
        global_step=self.global_step, model=self.model)

    # Lazily built by the validation_losses / validation_metrics properties.
    self._validation_losses = None
    self._validation_metrics = None

    # Builds per-task datasets.
    # NOTE(review): the loop variable `task` shadows the `task` constructor
    # argument from here on; `self._task` was already assigned above, so the
    # shadowing is harmless but worth confirming in any refactor.
    self.eval_datasets = {}
    for name, task in self.task.tasks.items():
      self.eval_datasets[name] = orbit.utils.make_distributed_dataset(
          self.strategy, task.build_inputs, task.task_config.validation_data)

    # Builds per-task validation loops.
    def get_function(task_name, task):

      # Resolve the metrics/loss and the (sub-)model for this task once,
      # outside the traced step function.
      task_metrics = self.validation_metrics[task_name]
      task_loss = self.validation_losses[task_name]
      if isinstance(self.model, base_model.MultiTaskBaseModel):
        model = self.model.sub_tasks[task_name]
      else:
        model = self.model

      def step_fn(inputs):
        # Runs the task's validation step and folds its loss into the
        # per-task mean-loss metric.
        logs = task.validation_step(inputs, model=model, metrics=task_metrics)
        task_loss.update_state(logs[task.loss])
        return logs

      @tf.function
      def eval_step_fn(iterator):
        # One distributed step; outputs are unwrapped to per-replica values.
        distributed_outputs = self.strategy.run(step_fn, args=(next(iterator),))
        return tf.nest.map_structure(self.strategy.experimental_local_results,
                                     distributed_outputs)

      return orbit.utils.create_loop_fn(eval_step_fn)

    self.task_fns = {
        name: get_function(name, task)
        for name, task in self.task.tasks.items()
    }

  @property
  def strategy(self):
    return self._strategy

  @property
  def task(self):
    return self._task

  @property
  def model(self):
    return self._model

  @property
  def global_step(self):
    return self._global_step

  @property
  def validation_losses(self):
    """Accesses the per-task validation loss metric objects."""
    if self._validation_losses is None:
      # Builds the per-task metrics and losses.
      self._validation_losses = {}
      for name in self.task.tasks:
        self._validation_losses[name] = tf.keras.metrics.Mean(
            "validation_loss", dtype=tf.float32)
    return self._validation_losses

  @property
  def validation_metrics(self):
    """Accesses all validation metric objects, keyed by task name."""
    if self._validation_metrics is None:
      # Builds the per-task metrics and losses.
      self._validation_metrics = {}
      for name, task in self.task.tasks.items():
        self._validation_metrics[name] = task.build_metrics(training=False)
    return self._validation_metrics

  @property
  def checkpoint(self):
    """Accesses the training checkpoint."""
    return self._checkpoint

  def evaluate(self, num_steps: tf.Tensor):
    """Performs evaluation for each `EvalTask`."""
    # Reset all loss and metric accumulators before a fresh evaluation pass.
    for metric in self.validation_losses.values():
      metric.reset_states()
    for metrics in self.validation_metrics.values():
      for metric in metrics:
        metric.reset_states()
    results = {}
    eval_iters = tf.nest.map_structure(iter, self.eval_datasets)

    for name, task_eval_loop in self.task_fns.items():
      outputs = None
      eval_iter = eval_iters[name]
      task = self.task.tasks[name]
      # Per-task eval steps override the caller-supplied step budget.
      task_eval_steps = self.task.task_eval_steps(name) or num_steps
      outputs = task_eval_loop(
          eval_iter,
          task_eval_steps,
          state=outputs,
          reduce_fn=task.aggregate_logs)
      task_metrics = self.validation_metrics[name]
      task_loss = self.validation_losses[name]
      logs = {}
      for metric in task_metrics + [task_loss]:
        logs[metric.name] = metric.result()
      if outputs:
        # Tasks that aggregate per-step outputs get a final reduction pass.
        metrics = task.reduce_aggregated_logs(
            outputs, global_step=self.global_step)
        logs.update(metrics)
      results[name] = logs
    if self._checkpoint_exporter:
      self._checkpoint_exporter.maybe_export_checkpoint(
          self.checkpoint, results, self.global_step.numpy())
    return results
official/modeling/multitask/evaluator_test.py
0 → 100644
View file @
f16a7b5b
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask.evaluator."""
from
absl.testing
import
parameterized
import
numpy
as
np
import
tensorflow
as
tf
from
tensorflow.python.distribute
import
combinations
from
tensorflow.python.distribute
import
strategy_combinations
from
official.core
import
base_task
from
official.core
import
config_definitions
as
cfg
from
official.modeling.multitask
import
evaluator
from
official.modeling.multitask
import
multitask
def all_strategy_combinations():
  """Returns eager-mode test combinations over the supported strategies."""
  strategies = [
      strategy_combinations.default_strategy,
      strategy_combinations.cloud_tpu_strategy,
      strategy_combinations.one_device_strategy_gpu,
  ]
  return combinations.combine(distribution=strategies, mode="eager")
class MockModel(tf.keras.Model):
  """Trivial model whose reported loss depends on the input structure.

  Adds a zero loss when the features contain a "y" key and a unit loss
  otherwise, so tests can tell which task's data reached the model.
  """

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.dense = tf.keras.layers.Dense(1)

  def call(self, inputs):
    # Fix: removed a leftover debug `print(inputs, type(inputs))` that
    # spammed test logs on every (traced) call.
    # The presence of the "y" key distinguishes the "bar" task inputs from
    # the "foo" task inputs (see MockTask.build_inputs).
    if "y" in inputs:
      self.add_loss(tf.zeros((1,), dtype=tf.float32))
    else:
      self.add_loss(tf.ones((1,), dtype=tf.float32))
    return self.dense(inputs["x"])
class MockTask(base_task.Task):
  """Mock task object for testing."""

  def build_metrics(self, training: bool = True):
    """Returns a single accuracy metric regardless of the training flag."""
    del training
    return [tf.keras.metrics.Accuracy(name="acc")]

  def build_inputs(self, params):
    """Builds an infinite dataset of constant features and labels."""

    def _make_example(_):
      features = tf.zeros(shape=(2,), dtype=tf.float32)
      label = tf.zeros([1], dtype=tf.int32)
      # The "bar" task carries an extra "y" key so MockModel can tell the
      # two tasks apart.
      if self.name == "bar":
        return dict(x=features, y=features), label
      return dict(x=features), label

    return (
        tf.data.Dataset.range(1)
        .repeat()
        .map(_make_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        .prefetch(buffer_size=1)
        .batch(2, drop_remainder=True))

  def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
    """Adds a constant "counter" output on top of the base validation step."""
    logs = super().validation_step(inputs, model, metrics)
    logs["counter"] = tf.ones((1,), dtype=tf.float32)
    return logs

  def aggregate_logs(self, state, step_outputs):
    """Accumulates each step's per-replica outputs into per-key lists."""
    state = {} if state is None else state
    for key, value in step_outputs.items():
      stacked = np.concatenate(
          [np.expand_dims(v.numpy(), axis=0) for v in value])
      state.setdefault(key, []).append(stacked)
    return state

  def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
    """Reduces each accumulated list of arrays to a scalar sum."""
    for key in aggregated_logs:
      aggregated_logs[key] = np.sum(np.stack(aggregated_logs[key], axis=0))
    return aggregated_logs
class EvaluatorTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for MultiTaskEvaluator over mock tasks and models."""

  @combinations.generate(all_strategy_combinations())
  def test_multitask_evaluator(self, distribution):
    with distribution.scope():
      mock_tasks = [
          MockTask(params=cfg.TaskConfig(), name="bar"),
          MockTask(params=cfg.TaskConfig(), name="foo")
      ]
      joint_task = multitask.MultiTask(tasks=mock_tasks)
      mock_model = MockModel()
      test_evaluator = evaluator.MultiTaskEvaluator(
          task=joint_task, model=mock_model)
      results = test_evaluator.evaluate(
          tf.convert_to_tensor(1, dtype=tf.int32))
      self.assertContainsSubset(["validation_loss", "acc"],
                                results["bar"].keys())
      self.assertContainsSubset(["validation_loss", "acc"],
                                results["foo"].keys())
      # MockModel reports zero loss for "bar" inputs and unit loss for "foo".
      self.assertEqual(results["bar"]["validation_loss"], 0.0)
      self.assertEqual(results["foo"]["validation_loss"], 1.0)

  @combinations.generate(all_strategy_combinations())
  def test_multitask_evaluator_numpy_metrics(self, distribution):
    with distribution.scope():
      mock_tasks = [
          MockTask(params=cfg.TaskConfig(), name="bar"),
          MockTask(params=cfg.TaskConfig(), name="foo")
      ]
      joint_task = multitask.MultiTask(tasks=mock_tasks)
      mock_model = MockModel()
      test_evaluator = evaluator.MultiTaskEvaluator(
          task=joint_task, model=mock_model)
      results = test_evaluator.evaluate(
          tf.convert_to_tensor(5, dtype=tf.int32))
      # "counter" is 1 per step per replica, summed by reduce_aggregated_logs.
      self.assertEqual(results["bar"]["counter"],
                       5. * distribution.num_replicas_in_sync)
      self.assertEqual(results["foo"]["counter"],
                       5. * distribution.num_replicas_in_sync)
# Run the test suite when this module is executed directly.
if __name__ == "__main__":
  tf.test.main()
official/modeling/multitask/interleaving_trainer.py
0 → 100644
View file @
f16a7b5b
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multitask trainer that interleaves each task's train step."""
from
typing
import
Union
import
gin
import
orbit
import
tensorflow
as
tf
from
official.modeling.multitask
import
base_model
from
official.modeling.multitask
import
base_trainer
from
official.modeling.multitask
import
multitask
from
official.modeling.multitask
import
task_sampler
as
sampler
@gin.configurable
class MultiTaskInterleavingTrainer(base_trainer.MultiTaskBaseTrainer):
  """MultiTask trainer that interleaves task update."""

  def __init__(self,
               multi_task: multitask.MultiTask,
               multi_task_model: Union[tf.keras.Model,
                                       base_model.MultiTaskBaseModel],
               optimizer: tf.optimizers.Optimizer,
               task_sampler: sampler.TaskSampler,
               trainer_options=None):
    """Initializes the interleaving trainer.

    Args:
      multi_task: a MultiTask instance managing the individual tasks.
      multi_task_model: the model (or MultiTaskBaseModel) shared by the tasks.
      optimizer: a tf.optimizers.Optimizer applied to the sampled task.
      task_sampler: a TaskSampler deciding which task trains at each step.
      trainer_options: optional options forwarded to the base trainer.
    """
    super(MultiTaskInterleavingTrainer, self).__init__(
        multi_task=multi_task,
        multi_task_model=multi_task_model,
        optimizer=optimizer,
        trainer_options=trainer_options)
    self._task_sampler = task_sampler

    # Build per task train step.
    def _get_task_step(task_name, task):

      def step_fn(inputs):
        # For a MultiTaskBaseModel, route to the per-task sub-model;
        # otherwise the single shared model handles every task.
        if isinstance(self.multi_task_model, base_model.MultiTaskBaseModel):
          task_model = self.multi_task_model.sub_tasks[task_name]
        else:
          task_model = self.multi_task_model
        task_logs = task.train_step(
            inputs,
            model=task_model,
            optimizer=self.optimizer,
            metrics=self.training_metrics[task_name])
        self.training_losses[task_name].update_state(task_logs[task.loss])

      return step_fn

    self._task_train_step_map = {
        name: _get_task_step(name, task)
        for name, task in self.multi_task.tasks.items()
    }

    # TODO(haozhangthu): Add taskwise step counter to train_loop_end for logging
    # on TensorBoard.
    self._task_step_counters = {
        name: orbit.utils.create_global_step() for name in self.multi_task.tasks
    }

  def task_step_counter(self, name):
    """Returns the step-counter variable of the task named `name`."""
    return self._task_step_counters[name]

  def train_step(self, iterator_map):
    """Samples one task and runs a single train step for it.

    Args:
      iterator_map: a dictionary of task names and per-task dataset iterators.
    """
    # Sample one task to train according to a multinomial distribution.
    # The draw is seeded by the global step, so it is deterministic per step.
    rn = tf.random.stateless_uniform(shape=[], seed=(0, self.global_step))
    cumulative_sample_distribution = self._task_sampler.task_cumulative_distribution(
        self.global_step)
    # Prepend a [0.0] for indexing convenience.
    cumulative_sample_distribution = tf.concat(
        [tf.constant([0.0], dtype=tf.float32), cumulative_sample_distribution],
        axis=0)

    # Find the CDF interval [begin, end) containing the draw; that task runs.
    for idx, (name, _) in enumerate(self.multi_task.tasks.items()):
      begin = cumulative_sample_distribution[idx]
      end = cumulative_sample_distribution[idx + 1]
      if rn >= begin and rn < end:
        self._strategy.run(
            self._task_train_step_map[name], args=(next(iterator_map[name]),))
        self.global_step.assign_add(1)
        self.task_step_counter(name).assign_add(1)
official/modeling/multitask/interleaving_trainer_test.py
0 → 100644
View file @
f16a7b5b
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask.interleaving_trainer."""
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
tensorflow.python.distribute
import
combinations
from
tensorflow.python.distribute
import
strategy_combinations
from
official.modeling.multitask
import
configs
from
official.modeling.multitask
import
interleaving_trainer
from
official.modeling.multitask
import
multitask
from
official.modeling.multitask
import
task_sampler
from
official.modeling.multitask
import
test_utils
def all_strategy_combinations():
  """Returns eager-mode test combinations over the supported strategies."""
  strategies = [
      strategy_combinations.default_strategy,
      strategy_combinations.cloud_tpu_strategy,
      strategy_combinations.one_device_strategy_gpu,
  ]
  return combinations.combine(distribution=strategies, mode="eager")
class InterleavingTrainerTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for MultiTaskInterleavingTrainer."""

  @combinations.generate(all_strategy_combinations())
  def test_multitask_interleaving_trainer(self, distribution):
    with distribution.scope():
      mock_tasks = [
          test_utils.MockFooTask(params=test_utils.FooConfig(), name="foo"),
          test_utils.MockBarTask(params=test_utils.BarConfig(), name="bar")
      ]
      joint_task = multitask.MultiTask(tasks=mock_tasks)
      sgd = tf.keras.optimizers.SGD(0.1)
      joint_model = test_utils.MockMultiTaskModel()
      uniform_sampler = task_sampler.UniformTaskSampler(
          task_weights=joint_task.task_weights)
      trainer = interleaving_trainer.MultiTaskInterleavingTrainer(
          multi_task=joint_task,
          multi_task_model=joint_model,
          optimizer=sgd,
          task_sampler=uniform_sampler)
      logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
      self.assertContainsSubset(["training_loss", "bar_acc"],
                                logs["bar"].keys())
      self.assertContainsSubset(["training_loss", "foo_acc"],
                                logs["foo"].keys())

  @combinations.generate(all_strategy_combinations())
  def test_trainer_with_configs(self, distribution):
    config = configs.MultiTaskConfig(
        task_routines=(
            configs.TaskRoutine(
                task_name="foo",
                task_config=test_utils.FooConfig(),
                task_weight=3.0),
            configs.TaskRoutine(
                task_name="bar",
                task_config=test_utils.BarConfig(),
                task_weight=1.0)))
    with distribution.scope():
      joint_task = multitask.MultiTask.from_config(config)
      sgd = tf.keras.optimizers.SGD(0.1)
      joint_model = test_utils.MockMultiTaskModel()
      num_step = 1000
      annealing_sampler = task_sampler.AnnealingTaskSampler(
          task_weights=joint_task.task_weights,
          steps_per_epoch=num_step / 5,
          total_steps=num_step)
      trainer = interleaving_trainer.MultiTaskInterleavingTrainer(
          multi_task=joint_task,
          multi_task_model=joint_model,
          optimizer=sgd,
          task_sampler=annealing_sampler)
      logs = trainer.train(tf.convert_to_tensor(num_step, dtype=tf.int32))
      self.assertContainsSubset(["training_loss", "bar_acc"],
                                logs["bar"].keys())
      self.assertContainsSubset(["training_loss", "foo_acc"],
                                logs["foo"].keys())
      self.assertEqual(trainer.global_step.numpy(), num_step)
      # Every train step advances exactly one task's counter, so the
      # per-task counters must partition the total step count.
      bar_sampled_step = trainer.task_step_counter("bar").numpy()
      foo_sampled_step = trainer.task_step_counter("foo").numpy()
      self.assertEqual(bar_sampled_step + foo_sampled_step, num_step)
# Run the test suite when this module is executed directly.
if __name__ == "__main__":
  tf.test.main()
official/modeling/multitask/multitask.py
0 → 100644
View file @
f16a7b5b
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Experimental MultiTask base class for multi-task training/evaluation."""
import
abc
from
typing
import
Dict
,
List
,
Optional
,
Text
,
Union
import
tensorflow
as
tf
from
official.core
import
base_task
from
official.core
import
config_definitions
from
official.core
import
task_factory
from
official.modeling
import
optimization
from
official.modeling.multitask
import
base_model
from
official.modeling.multitask
import
configs
# Short aliases for frequently used configuration types.
OptimizationConfig = optimization.OptimizationConfig
RuntimeConfig = config_definitions.RuntimeConfig
class MultiTask(tf.Module, metaclass=abc.ABCMeta):
  """A multi-task class to manage multiple tasks."""

  def __init__(self,
               tasks: Union[Dict[Text, base_task.Task], List[base_task.Task]],
               task_weights: Optional[Dict[str, Union[float, int]]] = None,
               task_eval_steps: Optional[Dict[str, int]] = None,
               name: Optional[str] = None):
    """MultiTask initialization.

    Args:
      tasks: a list or a flat dict of Task.
      task_weights: a dict of (task, task weight), task weight can be applied
        directly during loss summation in a joint backward step, or it can be
        used to sample task among interleaved backward step.
      task_eval_steps: a dict of (task, eval steps).
      name: the instance name of a MultiTask object.

    Raises:
      ValueError: if `tasks` is a list containing duplicated task names, or
        if `tasks` is neither a list nor a dict.
    """
    super().__init__(name=name)
    if isinstance(tasks, list):
      self._tasks = {}
      for task in tasks:
        if task.name in self._tasks:
          raise ValueError("Duplicated tasks found, task.name is %s" %
                           task.name)
        self._tasks[task.name] = task
    elif isinstance(tasks, dict):
      self._tasks = tasks
    else:
      raise ValueError("The tasks argument has an invalid type: %s" %
                       type(tasks))
    # Normalize so that every known task has an entry: eval steps default to
    # None (use the evaluator's budget) and weights default to 1.0.
    # (Replaces the original non-idiomatic `dict([(k, ...) for ...])` form.)
    task_eval_steps = task_eval_steps or {}
    self._task_eval_steps = {
        name: task_eval_steps.get(name, None) for name in self.tasks
    }
    task_weights = task_weights or {}
    self._task_weights = {
        name: task_weights.get(name, 1.0) for name in self.tasks
    }

  @classmethod
  def from_config(cls, config: configs.MultiTaskConfig, logging_dir=None):
    """Builds a MultiTask instance from a MultiTaskConfig."""
    tasks = {}
    task_eval_steps = {}
    task_weights = {}
    for task_routine in config.task_routines:
      task_name = task_routine.task_name
      tasks[task_name] = task_factory.get_task(
          task_routine.task_config, logging_dir=logging_dir)
      task_eval_steps[task_name] = task_routine.eval_steps
      task_weights[task_name] = task_routine.task_weight
    return cls(
        tasks, task_eval_steps=task_eval_steps, task_weights=task_weights)

  @property
  def tasks(self):
    """Returns the flat dict of task name to Task."""
    return self._tasks

  def task_eval_steps(self, task_name):
    """Returns the configured eval steps for `task_name` (may be None)."""
    return self._task_eval_steps[task_name]

  def task_weight(self, task_name):
    """Returns the loss/sampling weight of `task_name`."""
    return self._task_weights[task_name]

  @property
  def task_weights(self):
    """Returns the dict of task name to weight."""
    return self._task_weights

  @classmethod
  def create_optimizer(cls,
                       optimizer_config: OptimizationConfig,
                       runtime_config: Optional[RuntimeConfig] = None):
    """Creates an optimizer by delegating to the base Task implementation."""
    return base_task.Task.create_optimizer(
        optimizer_config=optimizer_config, runtime_config=runtime_config)

  def joint_train_step(self, task_inputs,
                       multi_task_model: base_model.MultiTaskBaseModel,
                       optimizer: tf.keras.optimizers.Optimizer, task_metrics):
    """The joint train step.

    Args:
      task_inputs: a dictionary of task names and per-task features.
      multi_task_model: a MultiTaskBaseModel instance.
      optimizer: a tf.optimizers.Optimizer.
      task_metrics: a dictionary of task names and per-task metrics.

    Returns:
      A dictionary of losses, including per-task losses and their weighted sum.

    Raises:
      ValueError: if a task's iterator output is neither a (features, labels)
        tuple nor a dict.
    """
    losses = {}
    with tf.GradientTape() as tape:
      total_loss = 0.0
      for name, model in multi_task_model.sub_tasks.items():
        inputs = task_inputs[name]
        if isinstance(inputs, tuple) and len(inputs) == 2:
          features, labels = inputs
        elif isinstance(inputs, dict):
          # Self-supervised style input: the dict serves as both features
          # and labels.
          features, labels = inputs, inputs
        else:
          raise ValueError("The iterator output is neither a tuple nor a "
                           "dictionary. It is not implemented to support "
                           "such outputs.")
        outputs = model(features, training=True)
        task_loss = self.tasks[name].build_losses(labels, outputs)
        task_weight = self.task_weight(name)
        total_loss += task_weight * task_loss
        losses[name] = task_loss
        self.tasks[name].process_metrics(task_metrics[name], labels, outputs)

      # Scales loss as the default gradients allreduce performs sum inside
      # the optimizer.
      scaled_loss = total_loss / tf.distribute.get_strategy(
      ).num_replicas_in_sync
    tvars = multi_task_model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    optimizer.apply_gradients(list(zip(grads, tvars)))
    losses["total_loss"] = total_loss
    return losses
official/modeling/multitask/task_sampler.py
0 → 100644
View file @
f16a7b5b
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils to sample tasks for interleaved optimization."""
import
abc
from
typing
import
Union
,
Dict
,
Text
import
tensorflow
as
tf
from
official.modeling.multitask
import
configs
class TaskSampler(tf.Module, metaclass=abc.ABCMeta):
  """An abstract class defining task sampling API for interleaving trainer."""

  def __init__(self, task_weights: Dict[Text, Union[float, int]]):
    # Mapping of task name to its (unnormalized) sampling weight.
    self._task_weights = task_weights

  @abc.abstractmethod
  def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
    """Compute cumulative distribution to sample tasks.

    It calculates the cumulative distribution of the multinomial task
    distribution with respect to which to be sampled against.

    Args:
      global_step: A tensor indicating current progress of training.

    Returns:
      A float tensor with shape (#(task), 1) that represents the cumulative
      sampling distribution.
    """
    pass
class UniformTaskSampler(TaskSampler):
  """Samples every task with equal probability, ignoring the task weights."""

  def __init__(self, task_weights: Dict[Text, Union[float, int]]):
    super(UniformTaskSampler, self).__init__(task_weights=task_weights)
    num_tasks = len(self._task_weights)
    # The distribution never changes during training, so build it once.
    equal_probs = tf.constant([1.0 / num_tasks] * num_tasks, dtype=tf.float32)
    self._uniform_cumulative = tf.math.cumsum(equal_probs)

  def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
    """Returns the fixed uniform cumulative distribution."""
    del global_step  # Unused: uniform sampling is independent of progress.
    return self._uniform_cumulative
class ProportionalTaskSampler(TaskSampler):
  """Samples tasks with probability proportional to task_weight ** alpha.

  The per-task weights are raised to the power `alpha`, normalized into a
  probability distribution, and accumulated once at construction time; the
  resulting cumulative distribution is constant over training.
  """

  def __init__(self,
               task_weights: Dict[Text, Union[float, int]],
               alpha: float = 1.0):
    """Initializes the sampler.

    Args:
      task_weights: Mapping from task name to its (float or int) weight.
      alpha: Exponent applied to each weight before normalization; alpha=1.0
        samples exactly proportionally to the raw weights.
    """
    super(ProportionalTaskSampler, self).__init__(task_weights=task_weights)
    self._alpha = tf.cast(alpha, dtype=tf.float32)
    task_weight_dict_ordered_list = tf.constant(
        [weight for _, weight in self._task_weights.items()], dtype=tf.float32)
    task_sizes = tf.math.pow(task_weight_dict_ordered_list, self._alpha)
    task_distribution = task_sizes / tf.reduce_sum(task_sizes)
    # Fixed spelling of the (private) attribute, previously
    # `_porportional_cumulative`; it is only referenced within this class.
    self._proportional_cumulative = tf.math.cumsum(task_distribution)

  def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
    """Returns the fixed weight-proportional cumulative distribution."""
    del global_step  # Unused: this distribution is independent of progress.
    return self._proportional_cumulative
class AnnealingTaskSampler(TaskSampler):
  """Samples tasks from weights whose influence varies with training progress.

  The exponent applied to the task weights is annealed as a function of the
  current epoch, so the sampling distribution changes over training.
  """

  def __init__(self,
               task_weights: Dict[Text, Union[float, int]],
               steps_per_epoch: int,
               total_steps: int):
    """Initializes the sampler.

    Args:
      task_weights: Mapping from task name to its (float or int) weight.
      steps_per_epoch: Number of training steps that make up one epoch.
      total_steps: Total number of training steps.
    """
    super(AnnealingTaskSampler, self).__init__(task_weights=task_weights)
    self._steps_per_epoch = tf.cast(steps_per_epoch, dtype=tf.float32)
    self._total_epochs = tf.cast(
        total_steps / self._steps_per_epoch, dtype=tf.float32)

  def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
    """Returns the epoch-dependent cumulative sampling distribution."""
    step_as_float = tf.cast(global_step, dtype=tf.float32)
    cur_epoch = tf.math.floor(step_as_float / self._steps_per_epoch)
    # The exponent decays linearly with the epoch index; the 1e-10 term
    # guards the division when there is effectively a single epoch.
    alpha = 1.0 - 0.8 * (cur_epoch - 1) / (self._total_epochs - 1 + 1e-10)
    ordered_weights = [
        weight for _, weight in self._task_weights.items()
    ]
    task_sizes = tf.math.pow(
        tf.constant(ordered_weights, dtype=tf.float32),
        tf.cast(alpha, dtype=tf.float32))
    dynamic_task_distribution = task_sizes / tf.reduce_sum(task_sizes)
    return tf.math.cumsum(dynamic_task_distribution)
def get_task_sampler(config: configs.TaskSamplingConfig,
                     task_weights: Dict[Text, float]) -> TaskSampler:
  """Creates a task sampler from a sampling config and per-task weights.

  Args:
    config: A `TaskSamplingConfig` whose `type` selects the sampler kind
      ('uniform', 'proportional' or 'annealing') and whose oneof payload
      carries the kind-specific parameters.
    task_weights: Mapping from task name to its sampling weight.

  Returns:
    A `TaskSampler` instance matching the configured type.

  Raises:
    RuntimeError: If `config.type` is not one of the supported kinds.
  """
  oneof_config = config.get()
  sampler_type = config.type
  if sampler_type == 'uniform':
    return UniformTaskSampler(task_weights=task_weights)
  if sampler_type == 'proportional':
    return ProportionalTaskSampler(
        task_weights=task_weights, alpha=oneof_config.alpha)
  if sampler_type == 'annealing':
    return AnnealingTaskSampler(
        task_weights=task_weights,
        steps_per_epoch=oneof_config.steps_per_epoch,
        total_steps=oneof_config.total_steps)
  raise RuntimeError('Task sampler type not supported')
official/modeling/multitask/task_sampler_test.py
0 → 100644
View file @
f16a7b5b
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask.task_sampler."""
import
tensorflow
as
tf
from
official.modeling.multitask
import
configs
from
official.modeling.multitask
import
task_sampler
as
sampler
class TaskSamplerTest(tf.test.TestCase):
  """Checks the cumulative distributions produced by each sampler kind."""

  def setUp(self):
    super(TaskSamplerTest, self).setUp()
    # Three tasks with distinct weights so proportional/annealing sampling
    # is distinguishable from uniform sampling.
    self._task_weights = {'A': 1.0, 'B': 2.0, 'C': 3.0}

  def test_uniform_sample_distribution(self):
    task_sampler = sampler.get_task_sampler(
        configs.TaskSamplingConfig(type='uniform'), self._task_weights)
    # The uniform distribution must not drift with the step counter.
    for step in range(5):
      cumulative = task_sampler.task_cumulative_distribution(
          tf.constant(step, dtype=tf.int64))
      self.assertAllClose([0.333333, 0.666666, 1.0], cumulative.numpy())

  def test_proportional_sample_distribution(self):
    task_sampler = sampler.get_task_sampler(
        configs.TaskSamplingConfig(
            type='proportional',
            proportional=configs.ProportionalSampleConfig(alpha=2.0)),
        self._task_weights)
    # CumulativeOf(Normalize([1.0^2, 2.0^2, 3.0^2])); constant across steps.
    for step in range(5):
      cumulative = task_sampler.task_cumulative_distribution(
          tf.constant(step, dtype=tf.int64))
      self.assertAllClose([0.07142857, 0.35714286, 1.0], cumulative.numpy())

  def test_annealing_sample_distribution(self):
    num_epoch = 3
    step_per_epoch = 6
    task_sampler = sampler.get_task_sampler(
        configs.TaskSamplingConfig(
            type='annealing',
            annealing=configs.AnnealingSampleConfig(
                steps_per_epoch=step_per_epoch,
                total_steps=step_per_epoch * num_epoch)),
        self._task_weights)
    global_step = tf.Variable(
        0, dtype=tf.int64, name='global_step', trainable=False)
    # One expected cumulative distribution per epoch; within an epoch the
    # distribution is constant because the exponent only depends on the epoch.
    expected_cumulative_epochs = [[0.12056106, 0.4387236, 1.0],
                                  [0.16666667, 0.5, 1.0],
                                  [0.22477472, 0.5654695, 1.0]]
    for epoch in range(num_epoch):
      for _ in range(step_per_epoch):
        cumulative = task_sampler.task_cumulative_distribution(
            tf.constant(global_step, dtype=tf.int64))
        global_step.assign_add(1)
        self.assertAllClose(expected_cumulative_epochs[epoch],
                            cumulative.numpy())
# Run all test cases in this module via the TensorFlow test runner.
if __name__ == '__main__':
  tf.test.main()
Prev
1
2
3
4
5
6
7
8
9
…
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment