Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
3d61d6b3
Commit
3d61d6b3
authored
Mar 30, 2023
by
qianyj
Browse files
initial files for ResNet50
parent
d3a70caf
Changes
166
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2569 additions
and
0 deletions
+2569
-0
official/modeling/activations/gelu_test.py
official/modeling/activations/gelu_test.py
+34
-0
official/modeling/activations/relu.py
official/modeling/activations/relu.py
+31
-0
official/modeling/activations/relu_test.py
official/modeling/activations/relu_test.py
+35
-0
official/modeling/activations/sigmoid.py
official/modeling/activations/sigmoid.py
+31
-0
official/modeling/activations/sigmoid_test.py
official/modeling/activations/sigmoid_test.py
+40
-0
official/modeling/activations/swish.py
official/modeling/activations/swish.py
+72
-0
official/modeling/activations/swish_test.py
official/modeling/activations/swish_test.py
+44
-0
official/modeling/grad_utils.py
official/modeling/grad_utils.py
+151
-0
official/modeling/hyperparams/__init__.py
official/modeling/hyperparams/__init__.py
+20
-0
official/modeling/hyperparams/base_config.py
official/modeling/hyperparams/base_config.py
+306
-0
official/modeling/hyperparams/base_config_test.py
official/modeling/hyperparams/base_config_test.py
+385
-0
official/modeling/hyperparams/oneof.py
official/modeling/hyperparams/oneof.py
+57
-0
official/modeling/hyperparams/oneof_test.py
official/modeling/hyperparams/oneof_test.py
+71
-0
official/modeling/hyperparams/params_dict.py
official/modeling/hyperparams/params_dict.py
+464
-0
official/modeling/hyperparams/params_dict_test.py
official/modeling/hyperparams/params_dict_test.py
+429
-0
official/modeling/multitask/__init__.py
official/modeling/multitask/__init__.py
+14
-0
official/modeling/multitask/base_model.py
official/modeling/multitask/base_model.py
+45
-0
official/modeling/multitask/base_trainer.py
official/modeling/multitask/base_trainer.py
+170
-0
official/modeling/multitask/base_trainer_test.py
official/modeling/multitask/base_trainer_test.py
+90
-0
official/modeling/multitask/configs.py
official/modeling/multitask/configs.py
+80
-0
No files found.
official/modeling/activations/gelu_test.py
0 → 100644
View file @
3d61d6b3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the Gaussian error linear unit."""
import
tensorflow
as
tf
from
tensorflow.python.keras
import
keras_parameterized
# pylint: disable=g-direct-tensorflow-import
from
official.modeling
import
activations
@
keras_parameterized
.
run_all_keras_modes
class
GeluTest
(
keras_parameterized
.
TestCase
):
def
test_gelu
(
self
):
expected_data
=
[[
0.14967535
,
0.
,
-
0.10032465
],
[
-
0.15880796
,
-
0.04540223
,
2.9963627
]]
gelu_data
=
activations
.
gelu
([[.
25
,
0
,
-
.
25
],
[
-
1
,
-
2
,
3
]])
self
.
assertAllClose
(
expected_data
,
gelu_data
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/modeling/activations/relu.py
0 → 100644
View file @
3d61d6b3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Customized Relu activation."""
import
tensorflow
as
tf
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Text'
)
def
relu6
(
features
):
"""Computes the Relu6 activation function.
Args:
features: A `Tensor` representing preactivation values.
Returns:
The activation value.
"""
features
=
tf
.
convert_to_tensor
(
features
)
return
tf
.
nn
.
relu6
(
features
)
official/modeling/activations/relu_test.py
0 → 100644
View file @
3d61d6b3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the customized Relu activation."""
import
tensorflow
as
tf
from
tensorflow.python.keras
import
\
keras_parameterized
# pylint: disable=g-direct-tensorflow-import
from
official.modeling
import
activations
@
keras_parameterized
.
run_all_keras_modes
class
CustomizedReluTest
(
keras_parameterized
.
TestCase
):
def
test_relu6
(
self
):
features
=
[[.
25
,
0
,
-
.
25
],
[
-
1
,
-
2
,
3
]]
customized_relu6_data
=
activations
.
relu6
(
features
)
relu6_data
=
tf
.
nn
.
relu6
(
features
)
self
.
assertAllClose
(
customized_relu6_data
,
relu6_data
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/modeling/activations/sigmoid.py
0 → 100644
View file @
3d61d6b3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Customized Sigmoid activation."""
import
tensorflow
as
tf
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Text'
)
def
hard_sigmoid
(
features
):
"""Computes the hard sigmoid activation function.
Args:
features: A `Tensor` representing preactivation values.
Returns:
The activation value.
"""
features
=
tf
.
convert_to_tensor
(
features
)
return
tf
.
nn
.
relu6
(
features
+
tf
.
cast
(
3.
,
features
.
dtype
))
*
0.16667
official/modeling/activations/sigmoid_test.py
0 → 100644
View file @
3d61d6b3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the customized Sigmoid activation."""
import
numpy
as
np
import
tensorflow
as
tf
from
tensorflow.python.keras
import
\
keras_parameterized
# pylint: disable=g-direct-tensorflow-import
from
official.modeling
import
activations
@
keras_parameterized
.
run_all_keras_modes
class
CustomizedSigmoidTest
(
keras_parameterized
.
TestCase
):
def
_hard_sigmoid_nn
(
self
,
x
):
x
=
np
.
float32
(
x
)
return
tf
.
nn
.
relu6
(
x
+
3.
)
*
0.16667
def
test_hard_sigmoid
(
self
):
features
=
[[.
25
,
0
,
-
.
25
],
[
-
1
,
-
2
,
3
]]
customized_hard_sigmoid_data
=
activations
.
hard_sigmoid
(
features
)
sigmoid_data
=
self
.
_hard_sigmoid_nn
(
features
)
self
.
assertAllClose
(
customized_hard_sigmoid_data
,
sigmoid_data
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/modeling/activations/swish.py
0 → 100644
View file @
3d61d6b3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Customized Swish activation."""
import
tensorflow
as
tf
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Text'
)
def
simple_swish
(
features
):
"""Computes the Swish activation function.
The tf.nn.swish operation uses a custom gradient to reduce memory usage.
Since saving custom gradients in SavedModel is currently not supported, and
one would not be able to use an exported TF-Hub module for fine-tuning, we
provide this wrapper that can allow to select whether to use the native
TensorFlow swish operation, or whether to use a customized operation that
has uses default TensorFlow gradient computation.
Args:
features: A `Tensor` representing preactivation values.
Returns:
The activation value.
"""
features
=
tf
.
convert_to_tensor
(
features
)
return
features
*
tf
.
nn
.
sigmoid
(
features
)
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Text'
)
def
hard_swish
(
features
):
"""Computes a hard version of the swish function.
This operation can be used to reduce computational cost and improve
quantization for edge devices.
Args:
features: A `Tensor` representing preactivation values.
Returns:
The activation value.
"""
features
=
tf
.
convert_to_tensor
(
features
)
fdtype
=
features
.
dtype
return
features
*
tf
.
nn
.
relu6
(
features
+
tf
.
cast
(
3.
,
fdtype
))
*
(
1.
/
6.
)
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Text'
)
def
identity
(
features
):
"""Computes the identity function.
Useful for helping in quantization.
Args:
features: A `Tensor` representing preactivation values.
Returns:
The activation value.
"""
features
=
tf
.
convert_to_tensor
(
features
)
return
tf
.
identity
(
features
)
official/modeling/activations/swish_test.py
0 → 100644
View file @
3d61d6b3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the customized Swish activation."""
import
numpy
as
np
import
tensorflow
as
tf
from
tensorflow.python.keras
import
keras_parameterized
# pylint: disable=g-direct-tensorflow-import
from
official.modeling
import
activations
@
keras_parameterized
.
run_all_keras_modes
class
CustomizedSwishTest
(
keras_parameterized
.
TestCase
):
def
_hard_swish_np
(
self
,
x
):
x
=
np
.
float32
(
x
)
return
x
*
np
.
clip
(
x
+
3
,
0
,
6
)
/
6
def
test_simple_swish
(
self
):
features
=
[[.
25
,
0
,
-
.
25
],
[
-
1
,
-
2
,
3
]]
customized_swish_data
=
activations
.
simple_swish
(
features
)
swish_data
=
tf
.
nn
.
swish
(
features
)
self
.
assertAllClose
(
customized_swish_data
,
swish_data
)
def
test_hard_swish
(
self
):
features
=
[[.
25
,
0
,
-
.
25
],
[
-
1
,
-
2
,
3
]]
customized_swish_data
=
activations
.
hard_swish
(
features
)
swish_data
=
self
.
_hard_swish_np
(
features
)
self
.
assertAllClose
(
customized_swish_data
,
swish_data
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/modeling/grad_utils.py
0 → 100644
View file @
3d61d6b3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Some gradient util functions to help users writing custom training loop."""
from
absl
import
logging
import
tensorflow
as
tf
def
_filter_grads
(
grads_and_vars
):
"""Filter out iterable with grad equal to None."""
grads_and_vars
=
tuple
(
grads_and_vars
)
if
not
grads_and_vars
:
return
grads_and_vars
filtered
=
[]
vars_with_empty_grads
=
[]
for
grad
,
var
in
grads_and_vars
:
if
grad
is
None
:
vars_with_empty_grads
.
append
(
var
)
else
:
filtered
.
append
((
grad
,
var
))
filtered
=
tuple
(
filtered
)
if
not
filtered
:
raise
ValueError
(
"No gradients provided for any variable: %s."
%
([
v
.
name
for
_
,
v
in
grads_and_vars
],))
if
vars_with_empty_grads
:
logging
.
warning
(
(
"Gradients do not exist for variables %s when minimizing the loss."
),
([
v
.
name
for
v
in
vars_with_empty_grads
]))
return
filtered
def
_filter_and_allreduce_gradients
(
grads_and_vars
,
allreduce_precision
=
"float32"
,
bytes_per_pack
=
0
):
"""Filter None grads and then allreduce gradients in specified precision.
This utils function is used when users intent to explicitly allreduce
gradients and customize gradients operations before and after allreduce.
The allreduced gradients are then passed to optimizer.apply_gradients(
experimental_aggregate_gradients=False).
Args:
grads_and_vars: gradients and variables pairs.
allreduce_precision: Whether to allreduce gradients in float32 or float16.
bytes_per_pack: A non-negative integer. Breaks collective operations into
packs of certain size. If it's zero, all gradients are in one pack.
Returns:
pairs of allreduced non-None gradients and variables.
"""
filtered_grads_and_vars
=
_filter_grads
(
grads_and_vars
)
(
grads
,
variables
)
=
zip
(
*
filtered_grads_and_vars
)
if
allreduce_precision
==
"float16"
:
grads
=
[
tf
.
cast
(
grad
,
"float16"
)
for
grad
in
grads
]
hints
=
tf
.
distribute
.
experimental
.
CommunicationOptions
(
bytes_per_pack
=
bytes_per_pack
)
allreduced_grads
=
tf
.
distribute
.
get_strategy
(
# pylint: disable=protected-access
).
extended
.
_replica_ctx_all_reduce
(
tf
.
distribute
.
ReduceOp
.
SUM
,
grads
,
hints
)
if
allreduce_precision
==
"float16"
:
allreduced_grads
=
[
tf
.
cast
(
grad
,
"float32"
)
for
grad
in
allreduced_grads
]
return
allreduced_grads
,
variables
def
_run_callbacks
(
callbacks
,
grads_and_vars
):
for
callback
in
callbacks
:
grads_and_vars
=
callback
(
grads_and_vars
)
return
grads_and_vars
def
minimize_using_explicit_allreduce
(
tape
,
optimizer
,
loss
,
trainable_variables
,
pre_allreduce_callbacks
=
None
,
post_allreduce_callbacks
=
None
,
allreduce_bytes_per_pack
=
0
):
"""Minimizes loss for one step by updating `trainable_variables`.
Minimizes loss for one step by updating `trainable_variables`.
This explicitly performs gradient allreduce, instead of relying on implicit
allreduce in optimizer.apply_gradients(). If training using FP16 mixed
precision, explicit allreduce will aggregate gradients in FP16 format.
For TPU and GPU training using FP32, explicit allreduce will aggregate
gradients in FP32 format.
Args:
tape: An instance of `tf.GradientTape`.
optimizer: An instance of `tf.keras.optimizers.Optimizer`.
loss: the loss tensor.
trainable_variables: A list of model Variables.
pre_allreduce_callbacks: A list of callback functions that takes gradients
and model variables pairs as input, manipulate them, and returns a new
gradients and model variables pairs. The callback functions will be
invoked in the list order and before gradients are allreduced. With
mixed precision training, the pre_allreduce_allbacks will be applied on
scaled_gradients. Default is no callbacks.
post_allreduce_callbacks: A list of callback functions that takes
gradients and model variables pairs as input, manipulate them, and
returns a new gradients and model variables paris. The callback
functions will be invoked in the list order and right before gradients
are applied to variables for updates. Default is no callbacks.
allreduce_bytes_per_pack: A non-negative integer. Breaks collective
operations into packs of certain size. If it's zero, all gradients are
in one pack.
"""
if
isinstance
(
optimizer
,
tf
.
keras
.
mixed_precision
.
LossScaleOptimizer
):
# FP16 GPU code path
with
tape
:
scaled_loss
=
optimizer
.
get_scaled_loss
(
loss
)
scaled_grads
=
tape
.
gradient
(
scaled_loss
,
trainable_variables
)
grads_and_vars
=
zip
(
scaled_grads
,
trainable_variables
)
if
pre_allreduce_callbacks
:
grads_and_vars
=
_run_callbacks
(
pre_allreduce_callbacks
,
grads_and_vars
)
(
allreduced_scaled_grads
,
filtered_training_vars
)
=
_filter_and_allreduce_gradients
(
grads_and_vars
,
allreduce_precision
=
"float16"
,
bytes_per_pack
=
allreduce_bytes_per_pack
)
allreduced_unscaled_grads
=
optimizer
.
get_unscaled_gradients
(
allreduced_scaled_grads
)
grads_and_vars
=
zip
(
allreduced_unscaled_grads
,
filtered_training_vars
)
else
:
# TPU or FP32 GPU code path
grads
=
tape
.
gradient
(
loss
,
trainable_variables
)
grads_and_vars
=
zip
(
grads
,
trainable_variables
)
if
pre_allreduce_callbacks
:
grads_and_vars
=
_run_callbacks
(
pre_allreduce_callbacks
,
grads_and_vars
)
(
allreduced_grads
,
filtered_training_vars
)
=
_filter_and_allreduce_gradients
(
grads_and_vars
,
allreduce_precision
=
"float32"
,
bytes_per_pack
=
allreduce_bytes_per_pack
)
grads_and_vars
=
zip
(
allreduced_grads
,
filtered_training_vars
)
if
post_allreduce_callbacks
:
grads_and_vars
=
_run_callbacks
(
post_allreduce_callbacks
,
grads_and_vars
)
optimizer
.
apply_gradients
(
grads_and_vars
,
experimental_aggregate_gradients
=
False
)
official/modeling/hyperparams/__init__.py
0 → 100644
View file @
3d61d6b3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Hyperparams package definition."""
# pylint: disable=g-multiple-import
from
official.modeling.hyperparams.base_config
import
*
from
official.modeling.hyperparams.oneof
import
*
from
official.modeling.hyperparams.params_dict
import
*
official/modeling/hyperparams/base_config.py
0 → 100644
View file @
3d61d6b3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base configurations to standardize experiments."""
import
copy
import
dataclasses
import
functools
import
inspect
from
typing
import
Any
,
List
,
Mapping
,
Optional
,
Type
from
absl
import
logging
import
tensorflow
as
tf
import
yaml
from
official.modeling.hyperparams
import
params_dict
_BOUND
=
set
()
def
bind
(
config_cls
):
"""Bind a class to config cls."""
if
not
inspect
.
isclass
(
config_cls
):
raise
ValueError
(
'The bind decorator is supposed to apply on the class '
f
'attribute. Received
{
config_cls
}
, not a class.'
)
def
decorator
(
builder
):
if
config_cls
in
_BOUND
:
raise
ValueError
(
'Inside a program, we should not bind the config with a'
' class twice.'
)
if
inspect
.
isclass
(
builder
):
config_cls
.
_BUILDER
=
builder
# pylint: disable=protected-access
elif
inspect
.
isfunction
(
builder
):
def
_wrapper
(
self
,
*
args
,
**
kwargs
):
# pylint: disable=unused-argument
return
builder
(
*
args
,
**
kwargs
)
config_cls
.
_BUILDER
=
_wrapper
# pylint: disable=protected-access
else
:
raise
ValueError
(
f
'The `BUILDER` type is not supported:
{
builder
}
'
)
_BOUND
.
add
(
config_cls
)
return
builder
return
decorator
@
dataclasses
.
dataclass
class
Config
(
params_dict
.
ParamsDict
):
"""The base configuration class that supports YAML/JSON based overrides.
Because of YAML/JSON serialization limitations, some semantics of dataclass
are not supported:
* It recursively enforces a allowlist of basic types and container types, so
it avoids surprises with copy and reuse caused by unanticipated types.
* Warning: it converts Dict to `Config` even within sequences,
e.g. for config = Config({'key': [([{'a': 42}],)]),
type(config.key[0][0][0]) is Config rather than dict.
If you define/annotate some field as Dict, the field will convert to a
`Config` instance and lose the dictionary type.
"""
# The class or method to bind with the params class.
_BUILDER
=
None
# It's safe to add bytes and other immutable types here.
IMMUTABLE_TYPES
=
(
str
,
int
,
float
,
bool
,
type
(
None
))
# It's safe to add set, frozenset and other collections here.
SEQUENCE_TYPES
=
(
list
,
tuple
)
default_params
:
dataclasses
.
InitVar
[
Optional
[
Mapping
[
str
,
Any
]]]
=
None
restrictions
:
dataclasses
.
InitVar
[
Optional
[
List
[
str
]]]
=
None
def
__post_init__
(
self
,
default_params
,
restrictions
):
super
().
__init__
(
default_params
=
default_params
,
restrictions
=
restrictions
)
@
property
def
BUILDER
(
self
):
return
self
.
_BUILDER
@
classmethod
def
_isvalidsequence
(
cls
,
v
):
"""Check if the input values are valid sequences.
Args:
v: Input sequence.
Returns:
True if the sequence is valid. Valid sequence includes the sequence
type in cls.SEQUENCE_TYPES and element type is in cls.IMMUTABLE_TYPES or
is dict or ParamsDict.
"""
if
not
isinstance
(
v
,
cls
.
SEQUENCE_TYPES
):
return
False
return
(
all
(
isinstance
(
e
,
cls
.
IMMUTABLE_TYPES
)
for
e
in
v
)
or
all
(
isinstance
(
e
,
dict
)
for
e
in
v
)
or
all
(
isinstance
(
e
,
params_dict
.
ParamsDict
)
for
e
in
v
))
@
classmethod
def
_import_config
(
cls
,
v
,
subconfig_type
):
"""Returns v with dicts converted to Configs, recursively."""
if
not
issubclass
(
subconfig_type
,
params_dict
.
ParamsDict
):
raise
TypeError
(
'Subconfig_type should be subclass of ParamsDict, found {!r}'
.
format
(
subconfig_type
))
if
isinstance
(
v
,
cls
.
IMMUTABLE_TYPES
):
return
v
elif
isinstance
(
v
,
cls
.
SEQUENCE_TYPES
):
# Only support one layer of sequence.
if
not
cls
.
_isvalidsequence
(
v
):
raise
TypeError
(
'Invalid sequence: only supports single level {!r} of {!r} or '
'dict or ParamsDict found: {!r}'
.
format
(
cls
.
SEQUENCE_TYPES
,
cls
.
IMMUTABLE_TYPES
,
v
))
import_fn
=
functools
.
partial
(
cls
.
_import_config
,
subconfig_type
=
subconfig_type
)
return
type
(
v
)(
map
(
import_fn
,
v
))
elif
isinstance
(
v
,
params_dict
.
ParamsDict
):
# Deepcopy here is a temporary solution for preserving type in nested
# Config object.
return
copy
.
deepcopy
(
v
)
elif
isinstance
(
v
,
dict
):
return
subconfig_type
(
v
)
else
:
raise
TypeError
(
'Unknown type: {!r}'
.
format
(
type
(
v
)))
@
classmethod
def
_export_config
(
cls
,
v
):
"""Returns v with Configs converted to dicts, recursively."""
if
isinstance
(
v
,
cls
.
IMMUTABLE_TYPES
):
return
v
elif
isinstance
(
v
,
cls
.
SEQUENCE_TYPES
):
return
type
(
v
)(
map
(
cls
.
_export_config
,
v
))
elif
isinstance
(
v
,
params_dict
.
ParamsDict
):
return
v
.
as_dict
()
elif
isinstance
(
v
,
dict
):
raise
TypeError
(
'dict value not supported in converting.'
)
else
:
raise
TypeError
(
'Unknown type: {!r}'
.
format
(
type
(
v
)))
@
classmethod
def
_get_subconfig_type
(
cls
,
k
)
->
Type
[
params_dict
.
ParamsDict
]:
"""Get element type by the field name.
Args:
k: the key/name of the field.
Returns:
Config as default. If a type annotation is found for `k`,
1) returns the type of the annotation if it is subtype of ParamsDict;
2) returns the element type if the annotation of `k` is List[SubType]
or Tuple[SubType].
"""
subconfig_type
=
Config
if
k
in
cls
.
__annotations__
:
# Directly Config subtype.
type_annotation
=
cls
.
__annotations__
[
k
]
# pytype: disable=invalid-annotation
if
(
isinstance
(
type_annotation
,
type
)
and
issubclass
(
type_annotation
,
Config
)):
subconfig_type
=
cls
.
__annotations__
[
k
]
# pytype: disable=invalid-annotation
else
:
# Check if the field is a sequence of subtypes.
field_type
=
getattr
(
type_annotation
,
'__origin__'
,
type
(
None
))
if
(
isinstance
(
field_type
,
type
)
and
issubclass
(
field_type
,
cls
.
SEQUENCE_TYPES
)):
element_type
=
getattr
(
type_annotation
,
'__args__'
,
[
type
(
None
)])[
0
]
subconfig_type
=
(
element_type
if
issubclass
(
element_type
,
params_dict
.
ParamsDict
)
else
subconfig_type
)
return
subconfig_type
def
_set
(
self
,
k
,
v
):
"""Overrides same method in ParamsDict.
Also called by ParamsDict methods.
Args:
k: key to set.
v: value.
Raises:
RuntimeError
"""
subconfig_type
=
self
.
_get_subconfig_type
(
k
)
def
is_null
(
k
):
if
k
not
in
self
.
__dict__
or
not
self
.
__dict__
[
k
]:
return
True
return
False
if
isinstance
(
v
,
dict
):
if
is_null
(
k
):
# If the key not exist or the value is None, a new Config-family object
# sould be created for the key.
self
.
__dict__
[
k
]
=
subconfig_type
(
v
)
else
:
self
.
__dict__
[
k
].
override
(
v
)
elif
not
is_null
(
k
)
and
isinstance
(
v
,
self
.
SEQUENCE_TYPES
)
and
all
(
[
not
isinstance
(
e
,
self
.
IMMUTABLE_TYPES
)
for
e
in
v
]):
if
len
(
self
.
__dict__
[
k
])
==
len
(
v
):
for
i
in
range
(
len
(
v
)):
self
.
__dict__
[
k
][
i
].
override
(
v
[
i
])
elif
not
all
([
isinstance
(
e
,
self
.
IMMUTABLE_TYPES
)
for
e
in
v
]):
logging
.
warning
(
"The list/tuple don't match the value dictionaries provided. Thus, "
'the list/tuple is determined by the type annotation and '
'values provided. This is error-prone.'
)
self
.
__dict__
[
k
]
=
self
.
_import_config
(
v
,
subconfig_type
)
else
:
self
.
__dict__
[
k
]
=
self
.
_import_config
(
v
,
subconfig_type
)
else
:
self
.
__dict__
[
k
]
=
self
.
_import_config
(
v
,
subconfig_type
)
def
__setattr__
(
self
,
k
,
v
):
if
k
==
'BUILDER'
or
k
==
'_BUILDER'
:
raise
AttributeError
(
'`BUILDER` is a property and `_BUILDER` is the '
'reserved class attribute. We should only assign '
'`_BUILDER` at the class level.'
)
if
k
not
in
self
.
RESERVED_ATTR
:
if
getattr
(
self
,
'_locked'
,
False
):
raise
ValueError
(
'The Config has been locked. '
'No change is allowed.'
)
self
.
_set
(
k
,
v
)
def
_override
(
self
,
override_dict
,
is_strict
=
True
):
"""Overrides same method in ParamsDict.
Also called by ParamsDict methods.
Args:
override_dict: dictionary to write to .
is_strict: If True, not allows to add new keys.
Raises:
KeyError: overriding reserved keys or keys not exist (is_strict=True).
"""
for
k
,
v
in
sorted
(
override_dict
.
items
()):
if
k
in
self
.
RESERVED_ATTR
:
raise
KeyError
(
'The key {!r} is internally reserved. '
'Can not be overridden.'
.
format
(
k
))
if
k
not
in
self
.
__dict__
:
if
is_strict
:
raise
KeyError
(
'The key {!r} does not exist in {!r}. '
'To extend the existing keys, use '
'`override` with `is_strict` = False.'
.
format
(
k
,
type
(
self
)))
else
:
self
.
_set
(
k
,
v
)
else
:
if
isinstance
(
v
,
dict
)
and
self
.
__dict__
[
k
]:
self
.
__dict__
[
k
].
_override
(
v
,
is_strict
)
# pylint: disable=protected-access
elif
isinstance
(
v
,
params_dict
.
ParamsDict
)
and
self
.
__dict__
[
k
]:
self
.
__dict__
[
k
].
_override
(
v
.
as_dict
(),
is_strict
)
# pylint: disable=protected-access
else
:
self
.
_set
(
k
,
v
)
def
as_dict
(
self
):
"""Returns a dict representation of params_dict.ParamsDict.
For the nested params_dict.ParamsDict, a nested dict will be returned.
"""
return
{
k
:
self
.
_export_config
(
v
)
for
k
,
v
in
self
.
__dict__
.
items
()
if
k
not
in
self
.
RESERVED_ATTR
}
def
replace
(
self
,
**
kwargs
):
"""Overrides/returns a unlocked copy with the current config unchanged."""
# pylint: disable=protected-access
params
=
copy
.
deepcopy
(
self
)
params
.
_locked
=
False
params
.
_override
(
kwargs
,
is_strict
=
True
)
# pylint: enable=protected-access
return
params
@
classmethod
def
from_yaml
(
cls
,
file_path
:
str
):
# Note: This only works if the Config has all default values.
with
tf
.
io
.
gfile
.
GFile
(
file_path
,
'r'
)
as
f
:
loaded
=
yaml
.
load
(
f
,
Loader
=
yaml
.
FullLoader
)
config
=
cls
()
config
.
override
(
loaded
)
return
config
@
classmethod
def
from_json
(
cls
,
file_path
:
str
):
"""Wrapper for `from_yaml`."""
return
cls
.
from_yaml
(
file_path
)
@
classmethod
def
from_args
(
cls
,
*
args
,
**
kwargs
):
"""Builds a config from the given list of arguments."""
attributes
=
list
(
cls
.
__annotations__
.
keys
())
default_params
=
{
a
:
p
for
a
,
p
in
zip
(
attributes
,
args
)}
default_params
.
update
(
kwargs
)
return
cls
(
default_params
=
default_params
)
official/modeling/hyperparams/base_config_test.py
0 → 100644
View file @
3d61d6b3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
pprint
from
typing
import
List
,
Tuple
from
absl.testing
import
parameterized
import
dataclasses
import
tensorflow
as
tf
from
official.modeling.hyperparams
import
base_config
@
dataclasses
.
dataclass
class
DumpConfig1
(
base_config
.
Config
):
a
:
int
=
1
b
:
str
=
'text'
@
dataclasses
.
dataclass
class
DumpConfig2
(
base_config
.
Config
):
c
:
int
=
2
d
:
str
=
'text'
e
:
DumpConfig1
=
DumpConfig1
()
@
dataclasses
.
dataclass
class
DumpConfig3
(
DumpConfig2
):
f
:
int
=
2
g
:
str
=
'text'
h
:
List
[
DumpConfig1
]
=
dataclasses
.
field
(
default_factory
=
lambda
:
[
DumpConfig1
(),
DumpConfig1
()])
g
:
Tuple
[
DumpConfig1
,
...]
=
(
DumpConfig1
(),)
@
dataclasses
.
dataclass
class
DumpConfig4
(
DumpConfig2
):
x
:
int
=
3
@
dataclasses
.
dataclass
class
DummyConfig5
(
base_config
.
Config
):
y
:
Tuple
[
DumpConfig2
,
...]
=
(
DumpConfig2
(),
DumpConfig4
())
z
:
Tuple
[
str
]
=
(
'a'
,)
class
BaseConfigTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
def
assertHasSameTypes
(
self
,
c
,
d
,
msg
=
''
):
"""Checks if a Config has the same structure as a given dict.
Args:
c: the Config object to be check.
d: the reference dict object.
msg: The error message to show when type mismatched.
"""
# Make sure d is not a Config. Assume d is either
# dictionary or primitive type and c is the Config or primitive types.
self
.
assertNotIsInstance
(
d
,
base_config
.
Config
)
if
isinstance
(
d
,
base_config
.
Config
.
IMMUTABLE_TYPES
):
self
.
assertEqual
(
pprint
.
pformat
(
c
),
pprint
.
pformat
(
d
),
msg
=
msg
)
elif
isinstance
(
d
,
base_config
.
Config
.
SEQUENCE_TYPES
):
self
.
assertEqual
(
type
(
c
),
type
(
d
),
msg
=
msg
)
for
i
,
v
in
enumerate
(
d
):
self
.
assertHasSameTypes
(
c
[
i
],
v
,
msg
=
'{}[{!r}]'
.
format
(
msg
,
i
))
elif
isinstance
(
d
,
dict
):
self
.
assertIsInstance
(
c
,
base_config
.
Config
,
msg
=
msg
)
for
k
,
v
in
sorted
(
d
.
items
()):
self
.
assertHasSameTypes
(
getattr
(
c
,
k
),
v
,
msg
=
'{}[{!r}]'
.
format
(
msg
,
k
))
else
:
raise
TypeError
(
'Unknown type: %r'
%
type
(
d
))
def
assertImportExport
(
self
,
v
):
config
=
base_config
.
Config
({
'key'
:
v
})
back
=
config
.
as_dict
()[
'key'
]
self
.
assertEqual
(
pprint
.
pformat
(
back
),
pprint
.
pformat
(
v
))
self
.
assertHasSameTypes
(
config
.
key
,
v
,
msg
=
'=%s v'
%
pprint
.
pformat
(
v
))
def
test_invalid_keys
(
self
):
params
=
base_config
.
Config
()
with
self
.
assertRaises
(
AttributeError
):
_
=
params
.
a
def
test_cls
(
self
):
params
=
base_config
.
Config
()
with
self
.
assertRaisesRegex
(
AttributeError
,
'`BUILDER` is a property and `_BUILDER` is the reserved'
):
params
.
BUILDER
=
DumpConfig2
with
self
.
assertRaisesRegex
(
AttributeError
,
'`BUILDER` is a property and `_BUILDER` is the reserved'
):
params
.
_BUILDER
=
DumpConfig2
base_config
.
bind
(
DumpConfig1
)(
DumpConfig2
)
params
=
DumpConfig1
()
self
.
assertEqual
(
params
.
BUILDER
,
DumpConfig2
)
with
self
.
assertRaisesRegex
(
ValueError
,
'Inside a program, we should not bind'
):
base_config
.
bind
(
DumpConfig1
)(
DumpConfig2
)
def
_test
():
return
'test'
base_config
.
bind
(
DumpConfig2
)(
_test
)
params
=
DumpConfig2
()
self
.
assertEqual
(
params
.
BUILDER
(),
'test'
)
def
test_nested_config_types
(
self
):
config
=
DumpConfig3
()
self
.
assertIsInstance
(
config
.
e
,
DumpConfig1
)
self
.
assertIsInstance
(
config
.
h
[
0
],
DumpConfig1
)
self
.
assertIsInstance
(
config
.
h
[
1
],
DumpConfig1
)
self
.
assertIsInstance
(
config
.
g
[
0
],
DumpConfig1
)
config
.
override
({
'e'
:
{
'a'
:
2
,
'b'
:
'new text'
}})
self
.
assertIsInstance
(
config
.
e
,
DumpConfig1
)
self
.
assertEqual
(
config
.
e
.
a
,
2
)
self
.
assertEqual
(
config
.
e
.
b
,
'new text'
)
config
.
override
({
'h'
:
[{
'a'
:
3
,
'b'
:
'new text 2'
}]})
self
.
assertIsInstance
(
config
.
h
[
0
],
DumpConfig1
)
self
.
assertLen
(
config
.
h
,
1
)
self
.
assertEqual
(
config
.
h
[
0
].
a
,
3
)
self
.
assertEqual
(
config
.
h
[
0
].
b
,
'new text 2'
)
config
.
override
({
'g'
:
[{
'a'
:
4
,
'b'
:
'new text 3'
}]})
self
.
assertIsInstance
(
config
.
g
[
0
],
DumpConfig1
)
self
.
assertLen
(
config
.
g
,
1
)
self
.
assertEqual
(
config
.
g
[
0
].
a
,
4
)
self
.
assertEqual
(
config
.
g
[
0
].
b
,
'new text 3'
)
def
test_replace
(
self
):
config
=
DumpConfig2
()
new_config
=
config
.
replace
(
e
=
{
'a'
:
2
})
self
.
assertEqual
(
new_config
.
e
.
a
,
2
)
self
.
assertIsInstance
(
new_config
.
e
,
DumpConfig1
)
config
=
DumpConfig2
(
e
=
DumpConfig2
())
new_config
=
config
.
replace
(
e
=
{
'c'
:
4
})
self
.
assertEqual
(
new_config
.
e
.
c
,
4
)
self
.
assertIsInstance
(
new_config
.
e
,
DumpConfig2
)
config
=
DumpConfig3
()
new_config
=
config
.
replace
(
g
=
[{
'a'
:
4
,
'b'
:
'new text 3'
}])
self
.
assertIsInstance
(
new_config
.
g
[
0
],
DumpConfig1
)
self
.
assertEqual
(
new_config
.
g
[
0
].
a
,
4
)
@
parameterized
.
parameters
(
(
'_locked'
,
"The key '_locked' is internally reserved."
),
(
'_restrictions'
,
"The key '_restrictions' is internally reserved."
),
(
'aa'
,
"The key 'aa' does not exist."
),
)
def
test_key_error
(
self
,
key
,
msg
):
params
=
base_config
.
Config
()
with
self
.
assertRaisesRegex
(
KeyError
,
msg
):
params
.
override
({
key
:
True
})
@
parameterized
.
parameters
(
(
'str data'
,),
(
123
,),
(
1.23
,),
(
None
,),
([
'str'
,
1
,
2.3
,
None
],),
((
'str'
,
1
,
2.3
,
None
),),
)
def
test_import_export_immutable_types
(
self
,
v
):
self
.
assertImportExport
(
v
)
out
=
base_config
.
Config
({
'key'
:
v
})
self
.
assertEqual
(
pprint
.
pformat
(
v
),
pprint
.
pformat
(
out
.
key
))
def
test_override_is_strict_true
(
self
):
params
=
base_config
.
Config
({
'a'
:
'aa'
,
'b'
:
2
,
'c'
:
{
'c1'
:
'cc'
,
'c2'
:
20
}
})
params
.
override
({
'a'
:
2
,
'c'
:
{
'c1'
:
'ccc'
}},
is_strict
=
True
)
self
.
assertEqual
(
params
.
a
,
2
)
self
.
assertEqual
(
params
.
c
.
c1
,
'ccc'
)
with
self
.
assertRaises
(
KeyError
):
params
.
override
({
'd'
:
'ddd'
},
is_strict
=
True
)
with
self
.
assertRaises
(
KeyError
):
params
.
override
({
'c'
:
{
'c3'
:
30
}},
is_strict
=
True
)
config
=
base_config
.
Config
({
'key'
:
[{
'a'
:
42
}]})
with
self
.
assertRaisesRegex
(
KeyError
,
"The key 'b' does not exist"
):
config
.
override
({
'key'
:
[{
'b'
:
43
}]})
@
parameterized
.
parameters
(
(
lambda
x
:
x
,
'Unknown type'
),
(
object
(),
'Unknown type'
),
(
set
(),
'Unknown type'
),
(
frozenset
(),
'Unknown type'
),
)
def
test_import_unsupport_types
(
self
,
v
,
msg
):
with
self
.
assertRaisesRegex
(
TypeError
,
msg
):
_
=
base_config
.
Config
({
'key'
:
v
})
@
parameterized
.
parameters
(
({
'a'
:
[{
'b'
:
2
,
},
{
'c'
:
3
,
}]
},),
({
'c'
:
[{
'f'
:
1.1
,
},
{
'h'
:
[
1
,
2
],
}]
},),
(({
'a'
:
'aa'
,
'b'
:
2
,
'c'
:
{
'c1'
:
10
,
'c2'
:
20
,
}
},),),
)
def
test_import_export_nested_structure
(
self
,
d
):
self
.
assertImportExport
(
d
)
@
parameterized
.
parameters
(
([{
'a'
:
42
,
'b'
:
'hello'
,
'c'
:
1.2
}],),
(({
'a'
:
42
,
'b'
:
'hello'
,
'c'
:
1.2
},),),
)
def
test_import_export_nested_sequences
(
self
,
v
):
self
.
assertImportExport
(
v
)
@
parameterized
.
parameters
(
([([{}],)],),
([[
'str'
,
1
,
2.3
,
None
]],),
(((
'str'
,
1
,
2.3
,
None
),),),
([
(
'str'
,
1
,
2.3
,
None
),
],),
([
(
'str'
,
1
,
2.3
,
None
),
],),
([[{
'a'
:
42
,
'b'
:
'hello'
,
'c'
:
1.2
}]],),
([[[{
'a'
:
42
,
'b'
:
'hello'
,
'c'
:
1.2
}]]],),
((({
'a'
:
42
,
'b'
:
'hello'
,
'c'
:
1.2
},),),),
(((({
'a'
:
42
,
'b'
:
'hello'
,
'c'
:
1.2
},),),),),
([({
'a'
:
42
,
'b'
:
'hello'
,
'c'
:
1.2
},)],),
(([{
'a'
:
42
,
'b'
:
'hello'
,
'c'
:
1.2
}],),),
)
def
test_import_export_unsupport_sequence
(
self
,
v
):
with
self
.
assertRaisesRegex
(
TypeError
,
'Invalid sequence: only supports single level'
):
_
=
base_config
.
Config
({
'key'
:
v
})
def
test_construct_subtype
(
self
):
pass
def
test_import_config
(
self
):
params
=
base_config
.
Config
({
'a'
:
[{
'b'
:
2
},
{
'c'
:
{
'd'
:
3
}}]})
self
.
assertLen
(
params
.
a
,
2
)
self
.
assertEqual
(
params
.
a
[
0
].
b
,
2
)
self
.
assertEqual
(
type
(
params
.
a
[
0
]),
base_config
.
Config
)
self
.
assertEqual
(
pprint
.
pformat
(
params
.
a
[
0
].
b
),
'2'
)
self
.
assertEqual
(
type
(
params
.
a
[
1
]),
base_config
.
Config
)
self
.
assertEqual
(
type
(
params
.
a
[
1
].
c
),
base_config
.
Config
)
self
.
assertEqual
(
pprint
.
pformat
(
params
.
a
[
1
].
c
.
d
),
'3'
)
def
test_override
(
self
):
params
=
base_config
.
Config
({
'a'
:
[{
'b'
:
2
},
{
'c'
:
{
'd'
:
3
}}]})
params
.
override
({
'a'
:
[{
'b'
:
4
},
{
'c'
:
{
'd'
:
5
}}]},
is_strict
=
False
)
self
.
assertEqual
(
type
(
params
.
a
),
list
)
self
.
assertEqual
(
type
(
params
.
a
[
0
]),
base_config
.
Config
)
self
.
assertEqual
(
pprint
.
pformat
(
params
.
a
[
0
].
b
),
'4'
)
self
.
assertEqual
(
type
(
params
.
a
[
1
]),
base_config
.
Config
)
self
.
assertEqual
(
type
(
params
.
a
[
1
].
c
),
base_config
.
Config
)
self
.
assertEqual
(
pprint
.
pformat
(
params
.
a
[
1
].
c
.
d
),
'5'
)
@
parameterized
.
parameters
(
([{}],),
(({},),),
)
def
test_config_vs_params_dict
(
self
,
v
):
d
=
{
'key'
:
v
}
self
.
assertEqual
(
type
(
base_config
.
Config
(
d
).
key
[
0
]),
base_config
.
Config
)
self
.
assertEqual
(
type
(
base_config
.
params_dict
.
ParamsDict
(
d
).
key
[
0
]),
dict
)
def
test_ppformat
(
self
):
self
.
assertEqual
(
pprint
.
pformat
([
's'
,
1
,
1.0
,
True
,
None
,
{},
[],
(),
{
(
2
,):
(
3
,
[
4
],
{
6
:
7
,
}),
8
:
9
,
}
]),
"['s', 1, 1.0, True, None, {}, [], (), {8: 9, (2,): (3, [4], {6: 7})}]"
)
def
test_with_restrictions
(
self
):
restrictions
=
[
'e.a<c'
]
config
=
DumpConfig2
(
restrictions
=
restrictions
)
config
.
validate
()
def
test_nested_tuple
(
self
):
config
=
DummyConfig5
()
config
.
override
({
'y'
:
[{
'c'
:
4
,
'd'
:
'new text 3'
,
'e'
:
{
'a'
:
2
}
},
{
'c'
:
0
,
'd'
:
'new text 3'
,
'e'
:
{
'a'
:
2
}
}],
'z'
:
[
'a'
,
'b'
,
'c'
],
})
self
.
assertEqual
(
config
.
y
[
0
].
c
,
4
)
self
.
assertEqual
(
config
.
y
[
1
].
c
,
0
)
self
.
assertIsInstance
(
config
.
y
[
0
],
DumpConfig2
)
self
.
assertIsInstance
(
config
.
y
[
1
],
DumpConfig4
)
self
.
assertSameElements
(
config
.
z
,
[
'a'
,
'b'
,
'c'
])
def
test_override_by_empty_sequence
(
self
):
config
=
DummyConfig5
()
config
.
override
({
'y'
:
[],
'z'
:
(),
},
is_strict
=
True
)
self
.
assertEmpty
(
config
.
y
)
self
.
assertEmpty
(
config
.
z
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/modeling/hyperparams/oneof.py
0 → 100644
View file @
3d61d6b3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Config class that supports oneof functionality."""
from
typing
import
Optional
import
dataclasses
from
official.modeling.hyperparams
import
base_config
@
dataclasses
.
dataclass
class
OneOfConfig
(
base_config
.
Config
):
"""Configuration for configs with one of feature.
Attributes:
type: 'str', name of the field to select.
"""
type
:
Optional
[
str
]
=
None
def
as_dict
(
self
):
"""Returns a dict representation of OneOfConfig.
For the nested base_config.Config, a nested dict will be returned.
"""
if
self
.
type
is
None
:
return
{
'type'
:
None
}
elif
self
.
__dict__
[
'type'
]
not
in
self
.
__dict__
:
raise
ValueError
(
'type: {!r} is not a valid key!'
.
format
(
self
.
__dict__
[
'type'
]))
else
:
chosen_type
=
self
.
type
chosen_value
=
self
.
__dict__
[
chosen_type
]
return
{
'type'
:
self
.
type
,
chosen_type
:
self
.
_export_config
(
chosen_value
)}
def
get
(
self
):
"""Returns selected config based on the value of type.
If type is not set (None), None is returned.
"""
chosen_type
=
self
.
type
if
chosen_type
is
None
:
return
None
if
chosen_type
not
in
self
.
__dict__
:
raise
ValueError
(
'type: {!r} is not a valid key!'
.
format
(
self
.
type
))
return
self
.
__dict__
[
chosen_type
]
official/modeling/hyperparams/oneof_test.py
0 → 100644
View file @
3d61d6b3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
dataclasses
import
tensorflow
as
tf
from
official.modeling.hyperparams
import
base_config
from
official.modeling.hyperparams
import
oneof
@
dataclasses
.
dataclass
class
ResNet
(
base_config
.
Config
):
model_depth
:
int
=
50
@
dataclasses
.
dataclass
class
Backbone
(
oneof
.
OneOfConfig
):
type
:
str
=
'resnet'
resnet
:
ResNet
=
ResNet
()
not_resnet
:
int
=
2
@
dataclasses
.
dataclass
class
OutputLayer
(
oneof
.
OneOfConfig
):
type
:
str
=
'single'
single
:
int
=
1
multi_head
:
int
=
2
@
dataclasses
.
dataclass
class
Network
(
base_config
.
Config
):
backbone
:
Backbone
=
Backbone
()
output_layer
:
OutputLayer
=
OutputLayer
()
class
OneOfTest
(
tf
.
test
.
TestCase
):
def
test_to_dict
(
self
):
network_params
=
{
'backbone'
:
{
'type'
:
'resnet'
,
'resnet'
:
{
'model_depth'
:
50
}
},
'output_layer'
:
{
'type'
:
'single'
,
'single'
:
1000
}
}
network_config
=
Network
(
network_params
)
self
.
assertEqual
(
network_config
.
as_dict
(),
network_params
)
def
test_get_oneof
(
self
):
backbone
=
Backbone
()
self
.
assertIsInstance
(
backbone
.
get
(),
ResNet
)
self
.
assertEqual
(
backbone
.
get
().
as_dict
(),
{
'model_depth'
:
50
})
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/modeling/hyperparams/params_dict.py
0 → 100644
View file @
3d61d6b3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A parameter dictionary class which supports the nest structure."""
import
collections
import
copy
import
re
import
six
import
tensorflow
as
tf
import
yaml
# regex pattern that matches on key-value pairs in a comma-separated
# key-value pair string. It splits each k-v pair on the = sign, and
# matches on values that are within single quotes, double quotes, single
# values (e.g. floats, ints, etc.), and a lists within brackets.
_PARAM_RE
=
re
.
compile
(
r
"""
(?P<name>[a-zA-Z][\w\.]*) # variable name: "var" or "x"
\s*=\s*
((?P<val>\'(.*?)\' # single quote
|
\"(.*?)\" # double quote
|
[^,\[]* # single value
|
\[[^\]]*\])) # list of values
($|,\s*)"""
,
re
.
VERBOSE
)
_CONST_VALUE_RE
=
re
.
compile
(
r
'(\d.*|-\d.*|None)'
)
# Yaml loader with an implicit resolver to parse float decimal and exponential
# format. The regular experission parse the following cases:
# 1- Decimal number with an optional exponential term.
# 2- Integer number with an exponential term.
# 3- Decimal number with an optional exponential term.
# 4- Decimal number.
LOADER
=
yaml
.
SafeLoader
LOADER
.
add_implicit_resolver
(
'tag:yaml.org,2002:float'
,
re
.
compile
(
r
'''
^(?:[-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
|
[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
|
\\.[0-9_]+(?:[eE][-+][0-9]+)?
|
[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*)$'''
,
re
.
X
),
list
(
'-+0123456789.'
))
class
ParamsDict
(
object
):
"""A hyperparameter container class."""
RESERVED_ATTR
=
[
'_locked'
,
'_restrictions'
]
def
__init__
(
self
,
default_params
=
None
,
restrictions
=
None
):
"""Instantiate a ParamsDict.
Instantiate a ParamsDict given a set of default parameters and a list of
restrictions. Upon initialization, it validates itself by checking all the
defined restrictions, and raise error if it finds inconsistency.
Args:
default_params: a Python dict or another ParamsDict object including the
default parameters to initialize.
restrictions: a list of strings, which define a list of restrictions to
ensure the consistency of different parameters internally. Each
restriction string is defined as a binary relation with a set of
operators, including {'==', '!=', '<', '<=', '>', '>='}.
"""
self
.
_locked
=
False
self
.
_restrictions
=
[]
if
restrictions
:
self
.
_restrictions
=
restrictions
if
default_params
is
None
:
default_params
=
{}
self
.
override
(
default_params
,
is_strict
=
False
)
def
_set
(
self
,
k
,
v
):
if
isinstance
(
v
,
dict
):
self
.
__dict__
[
k
]
=
ParamsDict
(
v
)
else
:
self
.
__dict__
[
k
]
=
copy
.
deepcopy
(
v
)
def
__setattr__
(
self
,
k
,
v
):
"""Sets the value of the existing key.
Note that this does not allow directly defining a new key. Use the
`override` method with `is_strict=False` instead.
Args:
k: the key string.
v: the value to be used to set the key `k`.
Raises:
KeyError: if k is not defined in the ParamsDict.
"""
if
k
not
in
ParamsDict
.
RESERVED_ATTR
:
if
k
not
in
self
.
__dict__
.
keys
():
raise
KeyError
(
'The key `%{}` does not exist. '
'To extend the existing keys, use '
'`override` with `is_strict` = True.'
.
format
(
k
))
if
self
.
_locked
:
raise
ValueError
(
'The ParamsDict has been locked. '
'No change is allowed.'
)
self
.
_set
(
k
,
v
)
def
__getattr__
(
self
,
k
):
"""Gets the value of the existing key.
Args:
k: the key string.
Returns:
the value of the key.
Raises:
AttributeError: if k is not defined in the ParamsDict.
"""
if
k
not
in
self
.
__dict__
.
keys
():
raise
AttributeError
(
'The key `{}` does not exist. '
.
format
(
k
))
return
self
.
__dict__
[
k
]
def
__contains__
(
self
,
key
):
"""Implements the membership test operator."""
return
key
in
self
.
__dict__
def
get
(
self
,
key
,
value
=
None
):
"""Accesses through built-in dictionary get method."""
return
self
.
__dict__
.
get
(
key
,
value
)
def
__delattr__
(
self
,
k
):
"""Deletes the key and removes its values.
Args:
k: the key string.
Raises:
AttributeError: if k is reserverd or not defined in the ParamsDict.
ValueError: if the ParamsDict instance has been locked.
"""
if
k
in
ParamsDict
.
RESERVED_ATTR
:
raise
AttributeError
(
'The key `{}` is reserved. No change is allowes. '
.
format
(
k
))
if
k
not
in
self
.
__dict__
.
keys
():
raise
AttributeError
(
'The key `{}` does not exist. '
.
format
(
k
))
if
self
.
_locked
:
raise
ValueError
(
'The ParamsDict has been locked. No change is allowed.'
)
del
self
.
__dict__
[
k
]
def
override
(
self
,
override_params
,
is_strict
=
True
):
"""Override the ParamsDict with a set of given params.
Args:
override_params: a dict or a ParamsDict specifying the parameters to be
overridden.
is_strict: a boolean specifying whether override is strict or not. If
True, keys in `override_params` must be present in the ParamsDict. If
False, keys in `override_params` can be different from what is currently
defined in the ParamsDict. In this case, the ParamsDict will be extended
to include the new keys.
"""
if
self
.
_locked
:
raise
ValueError
(
'The ParamsDict has been locked. No change is allowed.'
)
if
isinstance
(
override_params
,
ParamsDict
):
override_params
=
override_params
.
as_dict
()
self
.
_override
(
override_params
,
is_strict
)
# pylint: disable=protected-access
def
_override
(
self
,
override_dict
,
is_strict
=
True
):
"""The implementation of `override`."""
for
k
,
v
in
six
.
iteritems
(
override_dict
):
if
k
in
ParamsDict
.
RESERVED_ATTR
:
raise
KeyError
(
'The key `%{}` is internally reserved. '
'Can not be overridden.'
)
if
k
not
in
self
.
__dict__
.
keys
():
if
is_strict
:
raise
KeyError
(
'The key `{}` does not exist. '
'To extend the existing keys, use '
'`override` with `is_strict` = False.'
.
format
(
k
))
else
:
self
.
_set
(
k
,
v
)
else
:
if
isinstance
(
v
,
dict
):
self
.
__dict__
[
k
].
_override
(
v
,
is_strict
)
# pylint: disable=protected-access
elif
isinstance
(
v
,
ParamsDict
):
self
.
__dict__
[
k
].
_override
(
v
.
as_dict
(),
is_strict
)
# pylint: disable=protected-access
else
:
self
.
__dict__
[
k
]
=
copy
.
deepcopy
(
v
)
def
lock
(
self
):
"""Makes the ParamsDict immutable."""
self
.
_locked
=
True
def
as_dict
(
self
):
"""Returns a dict representation of ParamsDict.
For the nested ParamsDict, a nested dict will be returned.
"""
params_dict
=
{}
for
k
,
v
in
six
.
iteritems
(
self
.
__dict__
):
if
k
not
in
ParamsDict
.
RESERVED_ATTR
:
if
isinstance
(
v
,
ParamsDict
):
params_dict
[
k
]
=
v
.
as_dict
()
else
:
params_dict
[
k
]
=
copy
.
deepcopy
(
v
)
return
params_dict
def
validate
(
self
):
"""Validate the parameters consistency based on the restrictions.
This method validates the internal consistency using the pre-defined list of
restrictions. A restriction is defined as a string which specfiies a binary
operation. The supported binary operations are {'==', '!=', '<', '<=', '>',
'>='}. Note that the meaning of these operators are consistent with the
underlying Python immplementation. Users should make sure the define
restrictions on their type make sense.
For example, for a ParamsDict like the following
```
a:
a1: 1
a2: 2
b:
bb:
bb1: 10
bb2: 20
ccc:
a1: 1
a3: 3
```
one can define two restrictions like this
['a.a1 == b.ccc.a1', 'a.a2 <= b.bb.bb2']
What it enforces are:
- a.a1 = 1 == b.ccc.a1 = 1
- a.a2 = 2 <= b.bb.bb2 = 20
Raises:
KeyError: if any of the following happens
(1) any of parameters in any of restrictions is not defined in
ParamsDict,
(2) any inconsistency violating the restriction is found.
ValueError: if the restriction defined in the string is not supported.
"""
def
_get_kv
(
dotted_string
,
params_dict
):
"""Get keys and values indicated by dotted_string."""
if
_CONST_VALUE_RE
.
match
(
dotted_string
)
is
not
None
:
const_str
=
dotted_string
if
const_str
==
'None'
:
constant
=
None
else
:
constant
=
float
(
const_str
)
return
None
,
constant
else
:
tokenized_params
=
dotted_string
.
split
(
'.'
)
v
=
params_dict
for
t
in
tokenized_params
:
v
=
v
[
t
]
return
tokenized_params
[
-
1
],
v
def
_get_kvs
(
tokens
,
params_dict
):
if
len
(
tokens
)
!=
2
:
raise
ValueError
(
'Only support binary relation in restriction.'
)
stripped_tokens
=
[
t
.
strip
()
for
t
in
tokens
]
left_k
,
left_v
=
_get_kv
(
stripped_tokens
[
0
],
params_dict
)
right_k
,
right_v
=
_get_kv
(
stripped_tokens
[
1
],
params_dict
)
return
left_k
,
left_v
,
right_k
,
right_v
params_dict
=
self
.
as_dict
()
for
restriction
in
self
.
_restrictions
:
if
'=='
in
restriction
:
tokens
=
restriction
.
split
(
'=='
)
_
,
left_v
,
_
,
right_v
=
_get_kvs
(
tokens
,
params_dict
)
if
left_v
!=
right_v
:
raise
KeyError
(
'Found inconsistncy between key `{}` and key `{}`.'
.
format
(
tokens
[
0
],
tokens
[
1
]))
elif
'!='
in
restriction
:
tokens
=
restriction
.
split
(
'!='
)
_
,
left_v
,
_
,
right_v
=
_get_kvs
(
tokens
,
params_dict
)
if
left_v
==
right_v
:
raise
KeyError
(
'Found inconsistncy between key `{}` and key `{}`.'
.
format
(
tokens
[
0
],
tokens
[
1
]))
elif
'<'
in
restriction
:
tokens
=
restriction
.
split
(
'<'
)
_
,
left_v
,
_
,
right_v
=
_get_kvs
(
tokens
,
params_dict
)
if
left_v
>=
right_v
:
raise
KeyError
(
'Found inconsistncy between key `{}` and key `{}`.'
.
format
(
tokens
[
0
],
tokens
[
1
]))
elif
'<='
in
restriction
:
tokens
=
restriction
.
split
(
'<='
)
_
,
left_v
,
_
,
right_v
=
_get_kvs
(
tokens
,
params_dict
)
if
left_v
>
right_v
:
raise
KeyError
(
'Found inconsistncy between key `{}` and key `{}`.'
.
format
(
tokens
[
0
],
tokens
[
1
]))
elif
'>'
in
restriction
:
tokens
=
restriction
.
split
(
'>'
)
_
,
left_v
,
_
,
right_v
=
_get_kvs
(
tokens
,
params_dict
)
if
left_v
<=
right_v
:
raise
KeyError
(
'Found inconsistncy between key `{}` and key `{}`.'
.
format
(
tokens
[
0
],
tokens
[
1
]))
elif
'>='
in
restriction
:
tokens
=
restriction
.
split
(
'>='
)
_
,
left_v
,
_
,
right_v
=
_get_kvs
(
tokens
,
params_dict
)
if
left_v
<
right_v
:
raise
KeyError
(
'Found inconsistncy between key `{}` and key `{}`.'
.
format
(
tokens
[
0
],
tokens
[
1
]))
else
:
raise
ValueError
(
'Unsupported relation in restriction.'
)
def
read_yaml_to_params_dict
(
file_path
:
str
):
"""Reads a YAML file to a ParamsDict."""
with
tf
.
io
.
gfile
.
GFile
(
file_path
,
'r'
)
as
f
:
params_dict
=
yaml
.
load
(
f
,
Loader
=
LOADER
)
return
ParamsDict
(
params_dict
)
def
save_params_dict_to_yaml
(
params
,
file_path
):
"""Saves the input ParamsDict to a YAML file."""
with
tf
.
io
.
gfile
.
GFile
(
file_path
,
'w'
)
as
f
:
def
_my_list_rep
(
dumper
,
data
):
# u'tag:yaml.org,2002:seq' is the YAML internal tag for sequence.
return
dumper
.
represent_sequence
(
u
'tag:yaml.org,2002:seq'
,
data
,
flow_style
=
True
)
yaml
.
add_representer
(
list
,
_my_list_rep
)
yaml
.
dump
(
params
.
as_dict
(),
f
,
default_flow_style
=
False
)
def
nested_csv_str_to_json_str
(
csv_str
):
"""Converts a nested (using '.') comma-separated k=v string to a JSON string.
Converts a comma-separated string of key/value pairs that supports
nesting of keys to a JSON string. Nesting is implemented using
'.' between levels for a given key.
Spacing between commas and = is supported (e.g. there is no difference between
"a=1,b=2", "a = 1, b = 2", or "a=1, b=2") but there should be no spaces before
keys or after values (e.g. " a=1,b=2" and "a=1,b=2 " are not supported).
Note that this will only support values supported by CSV, meaning
values such as nested lists (e.g. "a=[[1,2,3],[4,5,6]]") are not
supported. Strings are supported as well, e.g. "a='hello'".
An example conversion would be:
"a=1, b=2, c.a=2, c.b=3, d.a.a=5"
to
"{ a: 1, b : 2, c: {a : 2, b : 3}, d: {a: {a : 5}}}"
Args:
csv_str: the comma separated string.
Returns:
the converted JSON string.
Raises:
ValueError: If csv_str is not in a comma separated string or
if the string is formatted incorrectly.
"""
if
not
csv_str
:
return
''
formatted_entries
=
[]
nested_map
=
collections
.
defaultdict
(
list
)
pos
=
0
while
pos
<
len
(
csv_str
):
m
=
_PARAM_RE
.
match
(
csv_str
,
pos
)
if
not
m
:
raise
ValueError
(
'Malformed hyperparameter value while parsing '
'CSV string: %s'
%
csv_str
[
pos
:])
pos
=
m
.
end
()
# Parse the values.
m_dict
=
m
.
groupdict
()
name
=
m_dict
[
'name'
]
v
=
m_dict
[
'val'
]
# If a GCS path (e.g. gs://...) is provided, wrap this in quotes
# as yaml.load would otherwise throw an exception
if
re
.
match
(
r
'(?=[^\"\'])(?=[gs://])'
,
v
):
v
=
'
\'
{}
\'
'
.
format
(
v
)
name_nested
=
name
.
split
(
'.'
)
if
len
(
name_nested
)
>
1
:
grouping
=
name_nested
[
0
]
value
=
'.'
.
join
(
name_nested
[
1
:])
+
'='
+
v
nested_map
[
grouping
].
append
(
value
)
else
:
formatted_entries
.
append
(
'%s : %s'
%
(
name
,
v
))
for
grouping
,
value
in
nested_map
.
items
():
value
=
','
.
join
(
value
)
value
=
nested_csv_str_to_json_str
(
value
)
formatted_entries
.
append
(
'%s : %s'
%
(
grouping
,
value
))
return
'{'
+
', '
.
join
(
formatted_entries
)
+
'}'
def
override_params_dict
(
params
,
dict_or_string_or_yaml_file
,
is_strict
):
"""Override a given ParamsDict using a dict, JSON/YAML/CSV string or YAML file.
The logic of the function is outlined below:
1. Test that the input is a dict. If not, proceed to 2.
2. Tests that the input is a string. If not, raise unknown ValueError
2.1. Test if the string is in a CSV format. If so, parse.
If not, proceed to 2.2.
2.2. Try loading the string as a YAML/JSON. If successful, parse to
dict and use it to override. If not, proceed to 2.3.
2.3. Try using the string as a file path and load the YAML file.
Args:
params: a ParamsDict object to be overridden.
dict_or_string_or_yaml_file: a Python dict, JSON/YAML/CSV string or path to
a YAML file specifying the parameters to be overridden.
is_strict: a boolean specifying whether override is strict or not.
Returns:
params: the overridden ParamsDict object.
Raises:
ValueError: if failed to override the parameters.
"""
if
not
dict_or_string_or_yaml_file
:
return
params
if
isinstance
(
dict_or_string_or_yaml_file
,
dict
):
params
.
override
(
dict_or_string_or_yaml_file
,
is_strict
)
elif
isinstance
(
dict_or_string_or_yaml_file
,
six
.
string_types
):
try
:
dict_or_string_or_yaml_file
=
(
nested_csv_str_to_json_str
(
dict_or_string_or_yaml_file
))
except
ValueError
:
pass
params_dict
=
yaml
.
load
(
dict_or_string_or_yaml_file
,
Loader
=
LOADER
)
if
isinstance
(
params_dict
,
dict
):
params
.
override
(
params_dict
,
is_strict
)
else
:
with
tf
.
io
.
gfile
.
GFile
(
dict_or_string_or_yaml_file
)
as
f
:
params
.
override
(
yaml
.
load
(
f
,
Loader
=
yaml
.
FullLoader
),
is_strict
)
else
:
raise
ValueError
(
'Unknown input type to parse.'
)
return
params
official/modeling/hyperparams/params_dict_test.py
0 → 100644
View file @
3d61d6b3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for params_dict.py."""
import
os
import
tensorflow
as
tf
import
yaml
from
official.modeling.hyperparams
import
params_dict
class
ParamsDictTest
(
tf
.
test
.
TestCase
):
def
test_init_from_an_empty_dict
(
self
):
params
=
params_dict
.
ParamsDict
()
with
self
.
assertRaises
(
AttributeError
):
_
=
params
.
a
with
self
.
assertRaises
(
KeyError
):
params
.
a
=
'aa'
def
test_init_from_a_dict
(
self
):
params
=
params_dict
.
ParamsDict
({
'a'
:
'aa'
,
'b'
:
2
})
self
.
assertEqual
(
params
.
a
,
'aa'
)
self
.
assertEqual
(
params
.
b
,
2
)
def
test_init_from_a_param_dict
(
self
):
params_init
=
params_dict
.
ParamsDict
({
'a'
:
'aa'
,
'b'
:
2
})
params
=
params_dict
.
ParamsDict
(
params_init
)
self
.
assertEqual
(
params
.
a
,
'aa'
)
self
.
assertEqual
(
params
.
b
,
2
)
def
test_lock
(
self
):
params
=
params_dict
.
ParamsDict
({
'a'
:
1
,
'b'
:
2
,
'c'
:
3
})
params
.
lock
()
with
self
.
assertRaises
(
ValueError
):
params
.
a
=
10
with
self
.
assertRaises
(
ValueError
):
params
.
override
({
'b'
:
20
})
with
self
.
assertRaises
(
ValueError
):
del
params
.
c
def
test_setattr
(
self
):
params
=
params_dict
.
ParamsDict
()
params
.
override
({
'a'
:
'aa'
,
'b'
:
2
,
'c'
:
None
},
is_strict
=
False
)
params
.
c
=
'ccc'
self
.
assertEqual
(
params
.
a
,
'aa'
)
self
.
assertEqual
(
params
.
b
,
2
)
self
.
assertEqual
(
params
.
c
,
'ccc'
)
def
test_getattr
(
self
):
params
=
params_dict
.
ParamsDict
()
params
.
override
({
'a'
:
'aa'
,
'b'
:
2
,
'c'
:
None
},
is_strict
=
False
)
self
.
assertEqual
(
params
.
a
,
'aa'
)
self
.
assertEqual
(
params
.
b
,
2
)
self
.
assertEqual
(
params
.
c
,
None
)
def
test_delattr
(
self
):
params
=
params_dict
.
ParamsDict
()
params
.
override
({
'a'
:
'aa'
,
'b'
:
2
,
'c'
:
None
,
'd'
:
{
'd1'
:
1
,
'd2'
:
10
}
},
is_strict
=
False
)
del
params
.
c
self
.
assertEqual
(
params
.
a
,
'aa'
)
self
.
assertEqual
(
params
.
b
,
2
)
with
self
.
assertRaises
(
AttributeError
):
_
=
params
.
c
del
params
.
d
with
self
.
assertRaises
(
AttributeError
):
_
=
params
.
d
.
d1
def
test_contains
(
self
):
params
=
params_dict
.
ParamsDict
()
params
.
override
({
'a'
:
'aa'
},
is_strict
=
False
)
self
.
assertIn
(
'a'
,
params
)
self
.
assertNotIn
(
'b'
,
params
)
def
test_get
(
self
):
params
=
params_dict
.
ParamsDict
()
params
.
override
({
'a'
:
'aa'
},
is_strict
=
False
)
self
.
assertEqual
(
params
.
get
(
'a'
),
'aa'
)
self
.
assertEqual
(
params
.
get
(
'b'
,
2
),
2
)
self
.
assertEqual
(
params
.
get
(
'b'
),
None
)
def
test_override_is_strict_true
(
self
):
params
=
params_dict
.
ParamsDict
({
'a'
:
'aa'
,
'b'
:
2
,
'c'
:
{
'c1'
:
'cc'
,
'c2'
:
20
}
})
params
.
override
({
'a'
:
2
,
'c'
:
{
'c1'
:
'ccc'
}},
is_strict
=
True
)
self
.
assertEqual
(
params
.
a
,
2
)
self
.
assertEqual
(
params
.
c
.
c1
,
'ccc'
)
with
self
.
assertRaises
(
KeyError
):
params
.
override
({
'd'
:
'ddd'
},
is_strict
=
True
)
with
self
.
assertRaises
(
KeyError
):
params
.
override
({
'c'
:
{
'c3'
:
30
}},
is_strict
=
True
)
def
test_override_is_strict_false
(
self
):
params
=
params_dict
.
ParamsDict
({
'a'
:
'aa'
,
'b'
:
2
,
'c'
:
{
'c1'
:
10
,
'c2'
:
20
}
})
params
.
override
({
'a'
:
2
,
'c'
:
{
'c3'
:
3000
}},
is_strict
=
False
)
self
.
assertEqual
(
params
.
a
,
2
)
self
.
assertEqual
(
params
.
c
.
c3
,
3000
)
params
.
override
({
'd'
:
'ddd'
},
is_strict
=
False
)
self
.
assertEqual
(
params
.
d
,
'ddd'
)
params
.
override
({
'c'
:
{
'c4'
:
4444
}},
is_strict
=
False
)
self
.
assertEqual
(
params
.
c
.
c4
,
4444
)
def
test_as_dict
(
self
):
params
=
params_dict
.
ParamsDict
({
'a'
:
'aa'
,
'b'
:
2
,
'c'
:
{
'c1'
:
10
,
'c2'
:
20
}
})
params_d
=
params
.
as_dict
()
self
.
assertEqual
(
params_d
[
'a'
],
'aa'
)
self
.
assertEqual
(
params_d
[
'b'
],
2
)
self
.
assertEqual
(
params_d
[
'c'
][
'c1'
],
10
)
self
.
assertEqual
(
params_d
[
'c'
][
'c2'
],
20
)
def
test_validate
(
self
):
# Raise error due to the unknown parameter.
with
self
.
assertRaises
(
KeyError
):
params
=
params_dict
.
ParamsDict
({
'a'
:
1
,
'b'
:
{
'a'
:
11
}},
[
'a == c'
])
params
.
validate
()
# OK to check equality of two nested dicts.
params
=
params_dict
.
ParamsDict
({
'a'
:
1
,
'b'
:
{
'a'
:
10
},
'c'
:
{
'a'
:
10
}
},
[
'b == c'
])
# Raise error due to inconsistency
with
self
.
assertRaises
(
KeyError
):
params
=
params_dict
.
ParamsDict
({
'a'
:
1
,
'c'
:
{
'a'
:
10
}},
[
'a == c.a'
])
params
.
validate
()
# Valid rule.
params
=
params_dict
.
ParamsDict
({
'a'
:
1
,
'c'
:
{
'a'
:
1
}},
[
'a == c.a'
])
# Overridding violates the existing rule, raise error upon validate.
params
.
override
({
'a'
:
11
})
with
self
.
assertRaises
(
KeyError
):
params
.
validate
()
# Valid restrictions with constant.
params
=
params_dict
.
ParamsDict
({
'a'
:
None
,
'c'
:
{
'a'
:
1
}
},
[
'a == None'
,
'c.a == 1'
])
params
.
validate
()
with
self
.
assertRaises
(
KeyError
):
params
=
params_dict
.
ParamsDict
({
'a'
:
4
,
'c'
:
{
'a'
:
1
}
},
[
'a == None'
,
'c.a == 1'
])
params
.
validate
()
class
ParamsDictIOTest
(
tf
.
test
.
TestCase
):
def
write_temp_file
(
self
,
filename
,
text
):
temp_file
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
filename
)
with
tf
.
io
.
gfile
.
GFile
(
temp_file
,
'w'
)
as
writer
:
writer
.
write
(
text
)
return
temp_file
def
test_save_params_dict_to_yaml
(
self
):
params
=
params_dict
.
ParamsDict
({
'a'
:
'aa'
,
'b'
:
2
,
'c'
:
{
'c1'
:
10
,
'c2'
:
20
}
})
output_yaml_file
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
'params.yaml'
)
params_dict
.
save_params_dict_to_yaml
(
params
,
output_yaml_file
)
with
tf
.
io
.
gfile
.
GFile
(
output_yaml_file
,
'r'
)
as
f
:
params_d
=
yaml
.
load
(
f
)
self
.
assertEqual
(
params
.
a
,
params_d
[
'a'
])
self
.
assertEqual
(
params
.
b
,
params_d
[
'b'
])
self
.
assertEqual
(
params
.
c
.
c1
,
params_d
[
'c'
][
'c1'
])
self
.
assertEqual
(
params
.
c
.
c2
,
params_d
[
'c'
][
'c2'
])
def
test_read_yaml_to_params_dict
(
self
):
input_yaml_file
=
self
.
write_temp_file
(
'params.yaml'
,
r
"""
a: 'aa'
b: 2
c:
c1: 10
c2: 20
"""
)
params
=
params_dict
.
read_yaml_to_params_dict
(
input_yaml_file
)
self
.
assertEqual
(
params
.
a
,
'aa'
)
self
.
assertEqual
(
params
.
b
,
2
)
self
.
assertEqual
(
params
.
c
.
c1
,
10
)
self
.
assertEqual
(
params
.
c
.
c2
,
20
)
def
test_override_params_dict_using_dict
(
self
):
params
=
params_dict
.
ParamsDict
({
'a'
:
1
,
'b'
:
2.5
,
'c'
:
[
3
,
4
],
'd'
:
'hello'
,
'e'
:
False
})
override_dict
=
{
'b'
:
5.2
,
'c'
:
[
30
,
40
]}
params
=
params_dict
.
override_params_dict
(
params
,
override_dict
,
is_strict
=
True
)
self
.
assertEqual
(
1
,
params
.
a
)
self
.
assertEqual
(
5.2
,
params
.
b
)
self
.
assertEqual
([
30
,
40
],
params
.
c
)
self
.
assertEqual
(
'hello'
,
params
.
d
)
self
.
assertEqual
(
False
,
params
.
e
)
def
test_override_params_dict_using_yaml_string
(
self
):
params
=
params_dict
.
ParamsDict
({
'a'
:
1
,
'b'
:
2.5
,
'c'
:
[
3
,
4
],
'd'
:
'hello'
,
'e'
:
False
})
override_yaml_string
=
"'b': 5.2
\n
'c': [30, 40]"
params
=
params_dict
.
override_params_dict
(
params
,
override_yaml_string
,
is_strict
=
True
)
self
.
assertEqual
(
1
,
params
.
a
)
self
.
assertEqual
(
5.2
,
params
.
b
)
self
.
assertEqual
([
30
,
40
],
params
.
c
)
self
.
assertEqual
(
'hello'
,
params
.
d
)
self
.
assertEqual
(
False
,
params
.
e
)
def
test_override_params_dict_using_json_string
(
self
):
params
=
params_dict
.
ParamsDict
({
'a'
:
1
,
'b'
:
{
'b1'
:
2
,
'b2'
:
[
2
,
3
],
},
'd'
:
{
'd1'
:
{
'd2'
:
'hello'
}
},
'e'
:
False
})
override_json_string
=
"{ b: { b2: [3, 4] }, d: { d1: { d2: 'hi' } } }"
params
=
params_dict
.
override_params_dict
(
params
,
override_json_string
,
is_strict
=
True
)
self
.
assertEqual
(
1
,
params
.
a
)
self
.
assertEqual
(
2
,
params
.
b
.
b1
)
self
.
assertEqual
([
3
,
4
],
params
.
b
.
b2
)
self
.
assertEqual
(
'hi'
,
params
.
d
.
d1
.
d2
)
self
.
assertEqual
(
False
,
params
.
e
)
def
test_override_params_dict_using_csv_string
(
self
):
params
=
params_dict
.
ParamsDict
({
'a'
:
1
,
'b'
:
{
'b1'
:
2
,
'b2'
:
[
2
,
3
],
},
'd'
:
{
'd1'
:
{
'd2'
:
'hello'
}
},
'e'
:
False
})
override_csv_string
=
"b.b2=[3,4], d.d1.d2='hi, world', e=gs://test"
params
=
params_dict
.
override_params_dict
(
params
,
override_csv_string
,
is_strict
=
True
)
self
.
assertEqual
(
1
,
params
.
a
)
self
.
assertEqual
(
2
,
params
.
b
.
b1
)
self
.
assertEqual
([
3
,
4
],
params
.
b
.
b2
)
self
.
assertEqual
(
'hi, world'
,
params
.
d
.
d1
.
d2
)
self
.
assertEqual
(
'gs://test'
,
params
.
e
)
# Test different float formats
override_csv_string
=
'b.b2=-1.e-3, d.d1.d2=+0.001, e=1e+3, a=-1.5E-3'
params
=
params_dict
.
override_params_dict
(
params
,
override_csv_string
,
is_strict
=
True
)
self
.
assertEqual
(
-
1e-3
,
params
.
b
.
b2
)
self
.
assertEqual
(
0.001
,
params
.
d
.
d1
.
d2
)
self
.
assertEqual
(
1e3
,
params
.
e
)
self
.
assertEqual
(
-
1.5e-3
,
params
.
a
)
def
test_override_params_dict_using_yaml_file
(
self
):
params
=
params_dict
.
ParamsDict
({
'a'
:
1
,
'b'
:
2.5
,
'c'
:
[
3
,
4
],
'd'
:
'hello'
,
'e'
:
False
})
override_yaml_file
=
self
.
write_temp_file
(
'params.yaml'
,
r
"""
b: 5.2
c: [30, 40]
"""
)
params
=
params_dict
.
override_params_dict
(
params
,
override_yaml_file
,
is_strict
=
True
)
self
.
assertEqual
(
1
,
params
.
a
)
self
.
assertEqual
(
5.2
,
params
.
b
)
self
.
assertEqual
([
30
,
40
],
params
.
c
)
self
.
assertEqual
(
'hello'
,
params
.
d
)
self
.
assertEqual
(
False
,
params
.
e
)
class
IOTest
(
tf
.
test
.
TestCase
):
def
test_basic_csv_str_to_json_str
(
self
):
csv_str
=
'a=1,b=2,c=3'
json_str
=
'{a : 1, b : 2, c : 3}'
converted_csv_str
=
params_dict
.
nested_csv_str_to_json_str
(
csv_str
)
self
.
assertEqual
(
converted_csv_str
,
json_str
)
def
test_basic_csv_str_load
(
self
):
csv_str
=
'a=1,b=2,c=3'
expected_output
=
{
'a'
:
1
,
'b'
:
2
,
'c'
:
3
}
converted_csv_str
=
params_dict
.
nested_csv_str_to_json_str
(
csv_str
)
converted_dict
=
yaml
.
load
(
converted_csv_str
)
self
.
assertDictEqual
(
converted_dict
,
expected_output
)
def
test_basic_nested_csv_str_to_json_str
(
self
):
csv_str
=
'a=1,b.b1=2'
json_str
=
'{a : 1, b : {b1 : 2}}'
converted_csv_str
=
params_dict
.
nested_csv_str_to_json_str
(
csv_str
)
self
.
assertEqual
(
converted_csv_str
,
json_str
)
def
test_basic_nested_csv_str_load
(
self
):
csv_str
=
'a=1,b.b1=2,c.c1=3'
expected_output
=
{
'a'
:
1
,
'b'
:
{
'b1'
:
2
},
'c'
:
{
'c1'
:
3
}}
converted_csv_str
=
params_dict
.
nested_csv_str_to_json_str
(
csv_str
)
converted_dict
=
yaml
.
load
(
converted_csv_str
)
self
.
assertDictEqual
(
converted_dict
,
expected_output
)
def
test_complex_nested_csv_str_to_json_str
(
self
):
csv_str
=
'a.aa.aaa.aaaaa.a=1'
json_str
=
'{a : {aa : {aaa : {aaaaa : {a : 1}}}}}'
converted_csv_str
=
params_dict
.
nested_csv_str_to_json_str
(
csv_str
)
self
.
assertEqual
(
converted_csv_str
,
json_str
)
def
test_complex_nested_csv_str_load
(
self
):
csv_str
=
'a.aa.aaa.aaaaa.a=1,a.a=2'
expected_output
=
{
'a'
:
{
'aa'
:
{
'aaa'
:
{
'aaaaa'
:
{
'a'
:
1
}}},
'a'
:
2
}}
converted_csv_str
=
params_dict
.
nested_csv_str_to_json_str
(
csv_str
)
converted_dict
=
yaml
.
load
(
converted_csv_str
)
self
.
assertDictEqual
(
converted_dict
,
expected_output
)
def
test_csv_str_load_supported_datatypes
(
self
):
csv_str
=
'a=1,b=2.,c=[1,2,3],d=
\'
hello, there
\'
,e=
\"
Hi.
\"
'
converted_csv_str
=
params_dict
.
nested_csv_str_to_json_str
(
csv_str
)
converted_dict
=
yaml
.
load
(
converted_csv_str
)
self
.
assertEqual
(
converted_dict
[
'a'
],
1
)
self
.
assertEqual
(
converted_dict
[
'b'
],
2.
)
self
.
assertEqual
(
converted_dict
[
'c'
],
[
1
,
2
,
3
])
self
.
assertEqual
(
converted_dict
[
'd'
],
'hello, there'
)
self
.
assertEqual
(
converted_dict
[
'e'
],
'Hi.'
)
def
test_csv_str_load_unsupported_datatypes
(
self
):
csv_str
=
'a=[[1,2,3],[4,5,6]]'
self
.
assertRaises
(
ValueError
,
params_dict
.
nested_csv_str_to_json_str
,
csv_str
)
def
test_csv_str_to_json_str_spacing
(
self
):
csv_str1
=
'a=1,b=2,c=3'
csv_str2
=
'a = 1, b = 2, c = 3'
json_str
=
'{a : 1, b : 2, c : 3}'
converted_csv_str1
=
params_dict
.
nested_csv_str_to_json_str
(
csv_str1
)
converted_csv_str2
=
params_dict
.
nested_csv_str_to_json_str
(
csv_str2
)
self
.
assertEqual
(
converted_csv_str1
,
converted_csv_str2
)
self
.
assertEqual
(
converted_csv_str1
,
json_str
)
self
.
assertEqual
(
converted_csv_str2
,
json_str
)
def
test_gcs_added_quotes
(
self
):
csv_str
=
'a=gs://abc, b=gs://def'
expected_output
=
'{a :
\'
gs://abc
\'
, b :
\'
gs://def
\'
}'
converted_csv_str
=
params_dict
.
nested_csv_str_to_json_str
(
csv_str
)
self
.
assertEqual
(
converted_csv_str
,
expected_output
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/modeling/multitask/__init__.py
0 → 100644
View file @
3d61d6b3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
official/modeling/multitask/base_model.py
0 → 100644
View file @
3d61d6b3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Abstraction of multi-task model."""
from
typing
import
Text
,
Dict
import
tensorflow
as
tf
class
MultiTaskBaseModel
(
tf
.
Module
):
"""Base class that holds multi-task model computation."""
def
__init__
(
self
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
self
.
_sub_tasks
=
self
.
_instantiate_sub_tasks
()
def
_instantiate_sub_tasks
(
self
)
->
Dict
[
Text
,
tf
.
keras
.
Model
]:
"""Abstract function that sets up the computation for each sub-task.
Returns:
A map from task name (as string) to a tf.keras.Model object that
represents the sub-task in the multi-task pool.
"""
raise
NotImplementedError
(
"_instantiate_sub_task_models() is not implemented."
)
@
property
def
sub_tasks
(
self
):
"""Fetch a map of task name (string) to task model (tf.keras.Model)."""
return
self
.
_sub_tasks
def
initialize
(
self
):
"""Optional function that loads a pre-train checkpoint."""
return
official/modeling/multitask/base_trainer.py
0 → 100644
View file @
3d61d6b3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multitask base trainer implementation.
The trainer derives from the Orbit `StandardTrainer` class.
"""
from
typing
import
Union
import
gin
import
orbit
import
tensorflow
as
tf
from
official.modeling
import
optimization
from
official.modeling.multitask
import
base_model
from
official.modeling.multitask
import
multitask
@
gin
.
configurable
class
MultiTaskBaseTrainer
(
orbit
.
StandardTrainer
):
"""Multitask base trainer."""
def
__init__
(
self
,
multi_task
:
multitask
.
MultiTask
,
multi_task_model
:
Union
[
tf
.
keras
.
Model
,
base_model
.
MultiTaskBaseModel
],
optimizer
:
tf
.
optimizers
.
Optimizer
,
trainer_options
=
None
,
train_datasets
=
None
):
self
.
_strategy
=
tf
.
distribute
.
get_strategy
()
self
.
_multi_task
=
multi_task
self
.
_multi_task_model
=
multi_task_model
self
.
_optimizer
=
optimizer
self
.
_training_losses
=
None
self
.
_training_metrics
=
None
self
.
_global_step
=
orbit
.
utils
.
create_global_step
()
# Creates a shadow copy of the weights to store weights moving average.
if
isinstance
(
self
.
_optimizer
,
optimization
.
ExponentialMovingAverage
)
and
not
self
.
_optimizer
.
has_shadow_copy
:
self
.
_optimizer
.
shadow_copy
(
multi_task_model
)
if
hasattr
(
self
.
multi_task_model
,
"checkpoint_items"
):
checkpoint_items
=
self
.
multi_task_model
.
checkpoint_items
else
:
checkpoint_items
=
{}
self
.
_checkpoint
=
tf
.
train
.
Checkpoint
(
model
=
self
.
multi_task_model
,
optimizer
=
self
.
optimizer
,
global_step
=
self
.
global_step
,
**
checkpoint_items
)
if
train_datasets
is
None
:
train_datasets
=
{}
for
name
,
task
in
self
.
multi_task
.
tasks
.
items
():
train_datasets
[
name
]
=
orbit
.
utils
.
make_distributed_dataset
(
self
.
strategy
,
task
.
build_inputs
,
task
.
task_config
.
train_data
)
super
().
__init__
(
train_dataset
=
train_datasets
,
options
=
trainer_options
or
orbit
.
StandardTrainerOptions
())
def
train_loop_begin
(
self
):
"""Clean up states that hold losses and metrics."""
for
_
,
train_loss_metric
in
self
.
training_losses
.
items
():
train_loss_metric
.
reset_states
()
for
_
,
metrics
in
self
.
training_metrics
.
items
():
for
metric
in
metrics
:
metric
.
reset_states
()
def
train_loop_end
(
self
):
"""Record loss and metric values per task."""
result
=
{}
for
task_name
,
loss
in
self
.
training_losses
.
items
():
result
[
task_name
]
=
{
loss
.
name
:
loss
.
result
()}
for
task_name
,
task_metrics
in
self
.
training_metrics
.
items
():
result
[
task_name
].
update
(
{
metric
.
name
:
metric
.
result
()
for
metric
in
task_metrics
})
# Note that, the learning rate schedule is managed by the keras optimizer
# internally, which respects the number of backward pass as `iterations`.
# The learning rate schedule does not follow the trainer logical global
# step of multiple tasks.
if
callable
(
self
.
optimizer
.
learning_rate
):
result
[
"learning_rate"
]
=
self
.
optimizer
.
learning_rate
(
self
.
optimizer
.
iterations
)
else
:
result
[
"learning_rate"
]
=
self
.
optimizer
.
learning_rate
return
result
@
property
def
checkpoint
(
self
):
"""Accesses the training checkpoint."""
return
self
.
_checkpoint
@
property
def
training_losses
(
self
):
"""Access training loss metric objects for all tasks."""
if
self
.
_training_losses
is
None
:
# Builds the per-task metrics and losses.
# This the total summed training loss of tasks in the joint training.
self
.
_training_losses
=
dict
(
total_loss
=
tf
.
keras
.
metrics
.
Mean
(
"training_loss"
,
dtype
=
tf
.
float32
))
for
name
in
self
.
multi_task
.
tasks
:
self
.
_training_losses
[
name
]
=
tf
.
keras
.
metrics
.
Mean
(
"training_loss"
,
dtype
=
tf
.
float32
)
return
self
.
_training_losses
@
property
def
training_metrics
(
self
):
"""Access training metric metric objects for all tasks."""
if
self
.
_training_metrics
is
None
:
# Builds the per-task metrics and losses.
self
.
_training_metrics
=
{}
for
name
,
task
in
self
.
multi_task
.
tasks
.
items
():
self
.
_training_metrics
[
name
]
=
task
.
build_metrics
(
training
=
True
)
return
self
.
_training_metrics
@
property
def
strategy
(
self
):
return
self
.
_strategy
@
property
def
multi_task
(
self
):
return
self
.
_multi_task
@
property
def
multi_task_model
(
self
):
return
self
.
_multi_task_model
@
property
def
optimizer
(
self
):
return
self
.
_optimizer
@
property
def
global_step
(
self
):
return
self
.
_global_step
def
train_step
(
self
,
iterator_map
):
"""The default train step calling the multi-task train step.
Args:
iterator_map: a dictionary of task names and per-task dataset iterators.
"""
def
step_fn
(
inputs
):
losses
=
self
.
multi_task
.
joint_train_step
(
inputs
,
multi_task_model
=
self
.
multi_task_model
,
optimizer
=
self
.
optimizer
,
task_metrics
=
self
.
training_metrics
)
for
key
,
loss
in
losses
.
items
():
self
.
training_losses
[
key
].
update_state
(
loss
)
self
.
strategy
.
run
(
step_fn
,
args
=
(
tf
.
nest
.
map_structure
(
next
,
iterator_map
),))
self
.
global_step
.
assign_add
(
1
)
official/modeling/multitask/base_trainer_test.py
0 → 100644
View file @
3d61d6b3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask.base_trainer."""
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
tensorflow.python.distribute
import
combinations
from
tensorflow.python.distribute
import
strategy_combinations
from
official.modeling.multitask
import
base_trainer
from
official.modeling.multitask
import
configs
from
official.modeling.multitask
import
multitask
from
official.modeling.multitask
import
test_utils
def
all_strategy_combinations
():
return
combinations
.
combine
(
distribution
=
[
strategy_combinations
.
default_strategy
,
strategy_combinations
.
cloud_tpu_strategy
,
strategy_combinations
.
one_device_strategy_gpu
,
],
mode
=
"eager"
,
)
class
BaseTrainerTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
@
combinations
.
generate
(
all_strategy_combinations
())
def
test_multitask_joint_trainer
(
self
,
distribution
):
with
distribution
.
scope
():
tasks
=
[
test_utils
.
MockFooTask
(
params
=
test_utils
.
FooConfig
(),
name
=
"foo"
),
test_utils
.
MockBarTask
(
params
=
test_utils
.
BarConfig
(),
name
=
"bar"
)
]
task_weights
=
{
"foo"
:
1.0
,
"bar"
:
1.0
}
test_multitask
=
multitask
.
MultiTask
(
tasks
=
tasks
,
task_weights
=
task_weights
)
test_optimizer
=
tf
.
keras
.
optimizers
.
SGD
(
0.1
)
model
=
test_utils
.
MockMultiTaskModel
()
test_trainer
=
base_trainer
.
MultiTaskBaseTrainer
(
multi_task
=
test_multitask
,
multi_task_model
=
model
,
optimizer
=
test_optimizer
)
results
=
test_trainer
.
train
(
tf
.
convert_to_tensor
(
5
,
dtype
=
tf
.
int32
))
self
.
assertContainsSubset
([
"training_loss"
,
"bar_acc"
],
results
[
"bar"
].
keys
())
self
.
assertContainsSubset
([
"training_loss"
,
"foo_acc"
],
results
[
"foo"
].
keys
())
def
test_trainer_with_configs
(
self
):
config
=
configs
.
MultiTaskConfig
(
task_routines
=
(
configs
.
TaskRoutine
(
task_name
=
"foo"
,
task_config
=
test_utils
.
FooConfig
(),
task_weight
=
0.5
),
configs
.
TaskRoutine
(
task_name
=
"bar"
,
task_config
=
test_utils
.
BarConfig
(),
task_weight
=
0.5
)))
test_multitask
=
multitask
.
MultiTask
.
from_config
(
config
)
test_optimizer
=
tf
.
keras
.
optimizers
.
SGD
(
0.1
)
model
=
test_utils
.
MockMultiTaskModel
()
test_trainer
=
base_trainer
.
MultiTaskBaseTrainer
(
multi_task
=
test_multitask
,
multi_task_model
=
model
,
optimizer
=
test_optimizer
)
results
=
test_trainer
.
train
(
tf
.
convert_to_tensor
(
5
,
dtype
=
tf
.
int32
))
self
.
assertContainsSubset
([
"training_loss"
,
"bar_acc"
],
results
[
"bar"
].
keys
())
self
.
assertContainsSubset
([
"training_loss"
,
"foo_acc"
],
results
[
"foo"
].
keys
())
self
.
assertEqual
(
test_multitask
.
task_weight
(
"foo"
),
0.5
)
self
.
assertEqual
(
test_trainer
.
global_step
.
numpy
(),
5
)
self
.
assertIn
(
"learning_rate"
,
results
)
if
__name__
==
"__main__"
:
tf
.
test
.
main
()
official/modeling/multitask/configs.py
0 → 100644
View file @
3d61d6b3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration definitions for multi-task training."""
from
typing
import
Optional
,
Tuple
import
dataclasses
from
official.core
import
config_definitions
as
cfg
from
official.modeling
import
hyperparams
@
dataclasses
.
dataclass
class
TaskRoutine
(
hyperparams
.
Config
):
# TODO(hongkuny): deprecate the task_name once we migrated client code.
task_name
:
str
=
""
task_config
:
cfg
.
TaskConfig
=
None
eval_steps
:
Optional
[
int
]
=
None
task_weight
:
Optional
[
float
]
=
1.0
@
dataclasses
.
dataclass
class
MultiTaskConfig
(
hyperparams
.
Config
):
init_checkpoint
:
str
=
""
model
:
hyperparams
.
Config
=
None
task_routines
:
Tuple
[
TaskRoutine
,
...]
=
()
@
dataclasses
.
dataclass
class
ProportionalSampleConfig
(
hyperparams
.
Config
):
alpha
:
float
=
1.0
@
dataclasses
.
dataclass
class
AnnealingSampleConfig
(
hyperparams
.
Config
):
steps_per_epoch
:
int
=
5
total_steps
:
int
=
20
@
dataclasses
.
dataclass
class
TaskSamplingConfig
(
hyperparams
.
OneOfConfig
):
type
:
str
=
""
uniform
:
hyperparams
.
Config
=
hyperparams
.
Config
()
proportional
:
ProportionalSampleConfig
=
ProportionalSampleConfig
()
annealing
:
AnnealingSampleConfig
=
AnnealingSampleConfig
()
@
dataclasses
.
dataclass
class
MultiTaskTrainerConfig
(
cfg
.
TrainerConfig
):
trainer_type
:
str
=
"interleaving"
task_sampler
:
TaskSamplingConfig
=
TaskSamplingConfig
(
type
=
"proportional"
)
@
dataclasses
.
dataclass
class
MultiTaskExperimentConfig
(
hyperparams
.
Config
):
"""An experiment config for multi-task training and multi-task evaluation."""
task
:
MultiTaskConfig
=
MultiTaskConfig
()
trainer
:
MultiTaskTrainerConfig
=
MultiTaskTrainerConfig
()
runtime
:
cfg
.
RuntimeConfig
=
cfg
.
RuntimeConfig
()
@
dataclasses
.
dataclass
class
MultiEvalExperimentConfig
(
cfg
.
ExperimentConfig
):
"""An experiment config for single-task training and multi-task evaluation.
Attributes:
eval_tasks: individual evaluation tasks.
"""
eval_tasks
:
Tuple
[
TaskRoutine
,
...]
=
()
Prev
1
2
3
4
5
6
7
…
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment