Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
dcuai
dlexamples
Commits
a32ffa95
Commit
a32ffa95
authored
Feb 03, 2023
by
qianyj
Browse files
update TensorFlow2x test method
parent
e286da17
Changes
268
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2603 additions
and
0 deletions
+2603
-0
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/activations/relu.py
...ation/models-master/official/modeling/activations/relu.py
+31
-0
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/activations/relu_test.py
.../models-master/official/modeling/activations/relu_test.py
+35
-0
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/activations/sigmoid.py
...on/models-master/official/modeling/activations/sigmoid.py
+31
-0
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/activations/sigmoid_test.py
...dels-master/official/modeling/activations/sigmoid_test.py
+40
-0
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/activations/swish.py
...tion/models-master/official/modeling/activations/swish.py
+72
-0
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/activations/swish_test.py
...models-master/official/modeling/activations/swish_test.py
+44
-0
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/fast_training/experimental/tf2_utils_2x_wide.py
.../modeling/fast_training/experimental/tf2_utils_2x_wide.py
+186
-0
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/fast_training/experimental/tf2_utils_2x_wide_test.py
...ling/fast_training/experimental/tf2_utils_2x_wide_test.py
+101
-0
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/fast_training/progressive/policies.py
...r/official/modeling/fast_training/progressive/policies.py
+178
-0
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/fast_training/progressive/train.py
...ster/official/modeling/fast_training/progressive/train.py
+69
-0
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/fast_training/progressive/train_lib.py
.../official/modeling/fast_training/progressive/train_lib.py
+126
-0
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/fast_training/progressive/train_lib_test.py
...cial/modeling/fast_training/progressive/train_lib_test.py
+183
-0
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/fast_training/progressive/trainer.py
...er/official/modeling/fast_training/progressive/trainer.py
+294
-0
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/fast_training/progressive/trainer_test.py
...ficial/modeling/fast_training/progressive/trainer_test.py
+238
-0
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/fast_training/progressive/utils.py
...ster/official/modeling/fast_training/progressive/utils.py
+56
-0
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/grad_utils.py
...ssification/models-master/official/modeling/grad_utils.py
+151
-0
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/hyperparams/__init__.py
...n/models-master/official/modeling/hyperparams/__init__.py
+20
-0
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/hyperparams/base_config.py
...odels-master/official/modeling/hyperparams/base_config.py
+306
-0
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/hyperparams/base_config_test.py
...-master/official/modeling/hyperparams/base_config_test.py
+385
-0
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/hyperparams/oneof.py
...tion/models-master/official/modeling/hyperparams/oneof.py
+57
-0
No files found.
Too many changes to show.
To preserve performance only
268 of 268+
files are displayed.
Plain diff
Email patch
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/activations/relu.py
0 → 100644
View file @
a32ffa95
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Customized Relu activation."""
import
tensorflow
as
tf
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Text'
)
def
relu6
(
features
):
"""Computes the Relu6 activation function.
Args:
features: A `Tensor` representing preactivation values.
Returns:
The activation value.
"""
features
=
tf
.
convert_to_tensor
(
features
)
return
tf
.
nn
.
relu6
(
features
)
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/activations/relu_test.py
0 → 100644
View file @
a32ffa95
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the customized Relu activation."""
import
tensorflow
as
tf
from
tensorflow.python.keras
import
\
keras_parameterized
# pylint: disable=g-direct-tensorflow-import
from
official.modeling
import
activations
@
keras_parameterized
.
run_all_keras_modes
class
CustomizedReluTest
(
keras_parameterized
.
TestCase
):
def
test_relu6
(
self
):
features
=
[[.
25
,
0
,
-
.
25
],
[
-
1
,
-
2
,
3
]]
customized_relu6_data
=
activations
.
relu6
(
features
)
relu6_data
=
tf
.
nn
.
relu6
(
features
)
self
.
assertAllClose
(
customized_relu6_data
,
relu6_data
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/activations/sigmoid.py
0 → 100644
View file @
a32ffa95
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Customized Sigmoid activation."""
import
tensorflow
as
tf
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Text'
)
def
hard_sigmoid
(
features
):
"""Computes the hard sigmoid activation function.
Args:
features: A `Tensor` representing preactivation values.
Returns:
The activation value.
"""
features
=
tf
.
convert_to_tensor
(
features
)
return
tf
.
nn
.
relu6
(
features
+
tf
.
cast
(
3.
,
features
.
dtype
))
*
0.16667
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/activations/sigmoid_test.py
0 → 100644
View file @
a32ffa95
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the customized Sigmoid activation."""
import
numpy
as
np
import
tensorflow
as
tf
from
tensorflow.python.keras
import
\
keras_parameterized
# pylint: disable=g-direct-tensorflow-import
from
official.modeling
import
activations
@
keras_parameterized
.
run_all_keras_modes
class
CustomizedSigmoidTest
(
keras_parameterized
.
TestCase
):
def
_hard_sigmoid_nn
(
self
,
x
):
x
=
np
.
float32
(
x
)
return
tf
.
nn
.
relu6
(
x
+
3.
)
*
0.16667
def
test_hard_sigmoid
(
self
):
features
=
[[.
25
,
0
,
-
.
25
],
[
-
1
,
-
2
,
3
]]
customized_hard_sigmoid_data
=
activations
.
hard_sigmoid
(
features
)
sigmoid_data
=
self
.
_hard_sigmoid_nn
(
features
)
self
.
assertAllClose
(
customized_hard_sigmoid_data
,
sigmoid_data
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/activations/swish.py
0 → 100644
View file @
a32ffa95
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Customized Swish activation."""
import
tensorflow
as
tf
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Text'
)
def
simple_swish
(
features
):
"""Computes the Swish activation function.
The tf.nn.swish operation uses a custom gradient to reduce memory usage.
Since saving custom gradients in SavedModel is currently not supported, and
one would not be able to use an exported TF-Hub module for fine-tuning, we
provide this wrapper that can allow to select whether to use the native
TensorFlow swish operation, or whether to use a customized operation that
has uses default TensorFlow gradient computation.
Args:
features: A `Tensor` representing preactivation values.
Returns:
The activation value.
"""
features
=
tf
.
convert_to_tensor
(
features
)
return
features
*
tf
.
nn
.
sigmoid
(
features
)
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Text'
)
def
hard_swish
(
features
):
"""Computes a hard version of the swish function.
This operation can be used to reduce computational cost and improve
quantization for edge devices.
Args:
features: A `Tensor` representing preactivation values.
Returns:
The activation value.
"""
features
=
tf
.
convert_to_tensor
(
features
)
fdtype
=
features
.
dtype
return
features
*
tf
.
nn
.
relu6
(
features
+
tf
.
cast
(
3.
,
fdtype
))
*
(
1.
/
6.
)
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Text'
)
def
identity
(
features
):
"""Computes the identity function.
Useful for helping in quantization.
Args:
features: A `Tensor` representing preactivation values.
Returns:
The activation value.
"""
features
=
tf
.
convert_to_tensor
(
features
)
return
tf
.
identity
(
features
)
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/activations/swish_test.py
0 → 100644
View file @
a32ffa95
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the customized Swish activation."""
import
numpy
as
np
import
tensorflow
as
tf
from
tensorflow.python.keras
import
keras_parameterized
# pylint: disable=g-direct-tensorflow-import
from
official.modeling
import
activations
@
keras_parameterized
.
run_all_keras_modes
class
CustomizedSwishTest
(
keras_parameterized
.
TestCase
):
def
_hard_swish_np
(
self
,
x
):
x
=
np
.
float32
(
x
)
return
x
*
np
.
clip
(
x
+
3
,
0
,
6
)
/
6
def
test_simple_swish
(
self
):
features
=
[[.
25
,
0
,
-
.
25
],
[
-
1
,
-
2
,
3
]]
customized_swish_data
=
activations
.
simple_swish
(
features
)
swish_data
=
tf
.
nn
.
swish
(
features
)
self
.
assertAllClose
(
customized_swish_data
,
swish_data
)
def
test_hard_swish
(
self
):
features
=
[[.
25
,
0
,
-
.
25
],
[
-
1
,
-
2
,
3
]]
customized_swish_data
=
activations
.
hard_swish
(
features
)
swish_data
=
self
.
_hard_swish_np
(
features
)
self
.
assertAllClose
(
customized_swish_data
,
swish_data
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/fast_training/experimental/tf2_utils_2x_wide.py
0 → 100644
View file @
a32ffa95
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Stacking model horizontally."""
from
absl
import
logging
import
numpy
as
np
import
tensorflow
as
tf
def
expand_vector
(
v
:
np
.
ndarray
)
->
np
.
ndarray
:
"""Expands a vector with batch dimensions.
Equivalent to expand_1_axis(v, epsilon=0.0, axis=-1)
Args:
v: A vector with shape [..., a].
Returns:
A vector with shape [..., 2 * a].
"""
return
np
.
repeat
(
v
,
2
,
axis
=-
1
)
def
expand_1_axis
(
w
:
np
.
ndarray
,
epsilon
:
float
,
axis
:
int
)
->
np
.
ndarray
:
"""Expands either the first dimension or the last dimension of w.
If `axis = 0`, the following constraint will be satisfied:
matmul(x, w) ==
matmul(expand_vector(x), expand_1_axis(w, epsilon=0.1, axis=0))
If `axis = -1`, the following constraint will be satisfied if `epsilon = 0.0`:
expand_vector(matmul(x, w)) ==
2 * matmul(x, expand_1_axis(w, epsilon=0.0, axis=-1))
Args:
w: Numpy array of shape [a_0, a_1, ..., a_i-1, a_i].
epsilon: Symmetric Noise added to expanded tensor.
axis: Must be either 0 or -1.
Returns:
Expanded numpy array.
"""
assert
axis
in
(
0
,
-
1
),
(
"Only support expanding the first or the last dimension. "
"Got: {}"
.
format
(
axis
))
rank
=
len
(
w
.
shape
)
d_w
=
np
.
random
.
normal
(
np
.
zeros_like
(
w
),
np
.
fabs
(
w
)
*
epsilon
,
w
.
shape
)
d_w
=
np
.
repeat
(
d_w
,
2
,
axis
=
axis
)
sign_flip
=
np
.
array
([
1
,
-
1
])
for
_
in
range
(
rank
-
1
):
sign_flip
=
np
.
expand_dims
(
sign_flip
,
axis
=-
1
if
axis
==
0
else
0
)
sign_flip
=
np
.
tile
(
sign_flip
,
[
w
.
shape
[
0
]]
+
[
1
]
*
(
rank
-
2
)
+
[
w
.
shape
[
-
1
]])
d_w
*=
sign_flip
w_expand
=
(
np
.
repeat
(
w
,
2
,
axis
=
axis
)
+
d_w
)
/
2
return
w_expand
def
expand_2_axes
(
w
:
np
.
ndarray
,
epsilon
:
float
)
->
np
.
ndarray
:
"""Expands the first dimension and the last dimension of w.
The following constraint will be satisfied:
expand_vector(matmul(x, w)) == matmul(expand_vector(x), expand_2_axes(w))
Args:
w: Numpy array of shape [a_0, a_1, ..., a_i-1, a_i].
epsilon: Symmetric Noise added to expanded tensor.
Returns:
Expanded numpy array.
"""
rank
=
len
(
w
.
shape
)
d_w
=
np
.
random
.
normal
(
np
.
zeros_like
(
w
),
np
.
fabs
(
w
)
*
epsilon
,
w
.
shape
)
d_w
=
np
.
repeat
(
np
.
repeat
(
d_w
,
2
,
axis
=
0
),
2
,
axis
=-
1
)
sign_flip
=
np
.
array
([
1
,
-
1
])
for
_
in
range
(
rank
-
1
):
sign_flip
=
np
.
expand_dims
(
sign_flip
,
axis
=-
1
)
sign_flip
=
np
.
tile
(
sign_flip
,
[
w
.
shape
[
0
]]
+
[
1
]
*
(
rank
-
2
)
+
[
w
.
shape
[
-
1
]
*
2
])
d_w
*=
sign_flip
w_expand
=
(
np
.
repeat
(
np
.
repeat
(
w
,
2
,
axis
=
0
),
2
,
axis
=-
1
)
+
d_w
)
/
2
return
w_expand
def
var_to_var
(
var_from
:
tf
.
Variable
,
var_to
:
tf
.
Variable
,
epsilon
:
float
):
"""Expands a variable to another variable.
Assume the shape of `var_from` is (a, b, ..., y, z), the shape of `var_to`
can be (a, ..., z * 2), (a * 2, ..., z * 2), (a * 2, ..., z)
If the shape of `var_to` is (a, ..., 2 * z):
For any x, tf.matmul(x, var_to) ~= expand_vector(tf.matmul(x, var_from)) / 2
Not that there will be noise added to the left hand side, if epsilon != 0.
If the shape of `var_to` is (2 * a, ..., z):
For any x, tf.matmul(expand_vector(x), var_to) == tf.matmul(x, var_from)
If the shape of `var_to` is (2 * a, ..., 2 * z):
For any x, tf.matmul(expand_vector(x), var_to) ==
expand_vector(tf.matmul(expand_vector(x), var_from))
Args:
var_from: input variable to expand.
var_to: output variable.
epsilon: the noise ratio that will be added, when splitting `var_from`.
"""
shape_from
=
var_from
.
shape
shape_to
=
var_to
.
shape
if
shape_from
==
shape_to
:
var_to
.
assign
(
var_from
)
elif
len
(
shape_from
)
==
1
and
len
(
shape_to
)
==
1
:
var_to
.
assign
(
expand_vector
(
var_from
.
numpy
()))
elif
shape_from
[
0
]
*
2
==
shape_to
[
0
]
and
shape_from
[
-
1
]
==
shape_to
[
-
1
]:
var_to
.
assign
(
expand_1_axis
(
var_from
.
numpy
(),
epsilon
=
epsilon
,
axis
=
0
))
elif
shape_from
[
0
]
==
shape_to
[
0
]
and
shape_from
[
-
1
]
*
2
==
shape_to
[
-
1
]:
var_to
.
assign
(
expand_1_axis
(
var_from
.
numpy
(),
epsilon
=
epsilon
,
axis
=-
1
))
elif
shape_from
[
0
]
*
2
==
shape_to
[
0
]
and
shape_from
[
-
1
]
*
2
==
shape_to
[
-
1
]:
var_to
.
assign
(
expand_2_axes
(
var_from
.
numpy
(),
epsilon
=
epsilon
))
else
:
raise
ValueError
(
"Shape not supported, {}, {}"
.
format
(
shape_from
,
shape_to
))
def
model_to_model_2x_wide
(
model_from
:
tf
.
Module
,
model_to
:
tf
.
Module
,
epsilon
:
float
=
0.1
):
"""Expands a model to a wider version.
Also makes sure that the output of the model is not changed after expanding.
For example:
```
model_narrow = tf.keras.Sequential()
model_narrow.add(tf.keras.Input(shape=(3,)))
model_narrow.add(tf.keras.layers.Dense(4))
model_narrow.add(tf.keras.layers.Dense(1))
model_wide = tf.keras.Sequential()
model_wide.add(tf.keras.Input(shape=(6,)))
model_wide.add(tf.keras.layers.Dense(8))
model_wide.add(tf.keras.layers.Dense(1))
model_to_model_2x_wide(model_narrow, model_wide)
assert model_narrow([[1, 2, 3]]) == model_wide([[1, 1, 2, 2, 3, 3]])
```
We assume that `model_from` and `model_to` has the same architecture and only
widths of them differ.
Args:
model_from: input model to expand.
model_to: output model whose variables will be assigned expanded values
according to `model_from`.
epsilon: the noise ratio that will be added, when splitting `var_from`.
"""
for
w_from
,
w_to
in
zip
(
model_from
.
trainable_variables
,
model_to
.
trainable_variables
):
logging
.
info
(
"expanding %s %s to %s %s"
,
w_from
.
name
,
w_from
.
shape
,
w_to
.
name
,
w_to
.
shape
)
var_to_var
(
w_from
,
w_to
,
epsilon
=
epsilon
)
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/fast_training/experimental/tf2_utils_2x_wide_test.py
0 → 100644
View file @
a32ffa95
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tf2_utils_2x_wide."""
import
numpy
as
np
import
tensorflow
as
tf
from
official.modeling.fast_training.experimental
import
tf2_utils_2x_wide
class
Tf2Utils2XWideTest
(
tf
.
test
.
TestCase
):
def
test_expand_vector
(
self
):
x
=
np
.
array
([
1
,
2
])
self
.
assertAllClose
(
tf2_utils_2x_wide
.
expand_vector
(
x
),
np
.
array
([
1
,
1
,
2
,
2
]))
def
test_expand_matrix
(
self
):
x
=
np
.
array
([[
1
,
2
],
[
3
,
4
]])
x
=
tf2_utils_2x_wide
.
expand_2_axes
(
x
,
epsilon
=
0.1
)
self
.
assertAllClose
(
x
[
0
,
:]
+
x
[
1
,
:],
np
.
array
([
1
,
1
,
2
,
2
]))
self
.
assertAllClose
(
x
[
2
,
:]
+
x
[
3
,
:],
np
.
array
([
3
,
3
,
4
,
4
]))
def
test_expand_matrix_axis_0
(
self
):
x
=
np
.
array
([[
1
,
2
],
[
3
,
4
]])
x
=
tf2_utils_2x_wide
.
expand_1_axis
(
x
,
axis
=
0
,
epsilon
=
0.1
)
self
.
assertAllClose
(
x
[
0
,
:]
+
x
[
1
,
:],
np
.
array
([
1
,
2
]))
self
.
assertAllClose
(
x
[
2
,
:]
+
x
[
3
,
:],
np
.
array
([
3
,
4
]))
def
test_expand_matrix_axis_1
(
self
):
x
=
np
.
array
([[
1
,
2
],
[
3
,
4
]])
x
=
tf2_utils_2x_wide
.
expand_1_axis
(
x
,
axis
=-
1
,
epsilon
=
0.1
)
self
.
assertAllClose
(
x
[:,
0
]
+
x
[:,
1
],
np
.
array
([
1
,
3
]))
self
.
assertAllClose
(
x
[:,
2
]
+
x
[:,
3
],
np
.
array
([
2
,
4
]))
def
test_expand_3d_tensor
(
self
):
x0
=
np
.
array
([
10
,
11
])
x1
=
np
.
array
([
10
,
10
,
11
,
11
])
w0
=
np
.
random
.
rand
(
2
,
2
)
w1
=
tf2_utils_2x_wide
.
expand_2_axes
(
w0
,
epsilon
=
0.1
)
o0
=
np
.
matmul
(
x0
,
w0
)
o1
=
np
.
matmul
(
x1
,
w1
)
self
.
assertAllClose
(
np
.
repeat
(
o0
,
2
,
axis
=-
1
),
o1
)
def
test_expand_3d_tensor_axis_0
(
self
):
x0
=
np
.
array
([
10
,
11
])
x1
=
np
.
array
([
10
,
10
,
11
,
11
])
w0
=
np
.
random
.
rand
(
2
,
2
)
w1
=
tf2_utils_2x_wide
.
expand_1_axis
(
w0
,
axis
=
0
,
epsilon
=
0.1
)
o0
=
np
.
matmul
(
x0
,
w0
)
o1
=
np
.
matmul
(
x1
,
w1
)
self
.
assertAllClose
(
o0
,
o1
)
def
test_expand_3d_tensor_axis_2
(
self
):
x
=
np
.
array
([
10
,
11
])
w0
=
np
.
random
.
rand
(
2
,
2
)
w1
=
tf2_utils_2x_wide
.
expand_1_axis
(
w0
,
axis
=-
1
,
epsilon
=
0.1
)
o0
=
np
.
matmul
(
x
,
w0
)
o1
=
np
.
matmul
(
x
,
w1
)
self
.
assertAllClose
(
o0
,
np
.
sum
(
o1
.
reshape
(
2
,
2
),
axis
=-
1
))
def
test_end_to_end
(
self
):
"""Covers expand_vector, expand_2_axes, and expand_1_axis."""
model_narrow
=
tf
.
keras
.
Sequential
()
model_narrow
.
add
(
tf
.
keras
.
Input
(
shape
=
(
3
,)))
model_narrow
.
add
(
tf
.
keras
.
layers
.
Dense
(
4
))
model_narrow
.
add
(
tf
.
keras
.
layers
.
Dense
(
4
))
model_narrow
.
add
(
tf
.
keras
.
layers
.
Dense
(
1
))
model_wide
=
tf
.
keras
.
Sequential
()
model_wide
.
add
(
tf
.
keras
.
Input
(
shape
=
(
6
,)))
model_wide
.
add
(
tf
.
keras
.
layers
.
Dense
(
8
))
model_wide
.
add
(
tf
.
keras
.
layers
.
Dense
(
8
))
model_wide
.
add
(
tf
.
keras
.
layers
.
Dense
(
1
))
x0
=
np
.
array
([[
1
,
2
,
3
]])
x1
=
np
.
array
([[
1
,
1
,
2
,
2
,
3
,
3
]])
# Call model once to build variables first.
_
,
_
=
model_narrow
(
x0
),
model_wide
(
x1
)
tf2_utils_2x_wide
.
model_to_model_2x_wide
(
model_narrow
,
model_wide
,
epsilon
=
0.2
)
self
.
assertAllClose
(
model_narrow
(
x0
),
model_wide
(
x1
),
rtol
=
1e-05
,
atol
=
1e-05
)
if
__name__
==
"__main__"
:
tf
.
test
.
main
()
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/fast_training/progressive/policies.py
0 → 100644
View file @
a32ffa95
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base ProgressivePolicy definition for progressive training.
To write a progressive model, subclass ProgressivePolicy and implement its
abstract methods to handle each training stage.
"""
import
abc
import
dataclasses
from
typing
import
Any
,
Mapping
from
absl
import
logging
import
six
import
tensorflow
as
tf
from
official.common
import
streamz_counters
from
official.modeling.fast_training.progressive
import
utils
from
official.modeling.hyperparams
import
base_config
@
dataclasses
.
dataclass
class
ProgressiveConfig
(
base_config
.
Config
):
pass
@
six
.
add_metaclass
(
abc
.
ABCMeta
)
class
ProgressivePolicy
:
"""The APIs for handling progressive training stages.
Attributes:
cur_model: The model for the current progressive training stage.
cur_train_dataset: The train dataset function for the current stage.
cur_eval_dataset: The eval dataset function for the current stage.
cur_optimizer: The optimizer for the current stage.
cur_checkpoint_items: Items to be saved in and restored from checkpoints,
for the progressive trainer.
is_last_stage: Whether it is currently in the last stage.
Interfaces:
is_stage_advancing: Returns if progressive training is advancing to the
next stage.
update_pt_stage: Update progressive training stage.
"""
def
__init__
(
self
):
"""Initialize stage policy."""
self
.
_cur_train_dataset
=
None
self
.
_cur_eval_dataset
=
None
self
.
_volatiles
=
utils
.
VolatileTrackable
(
optimizer
=
None
,
model
=
None
)
stage_id
=
0
self
.
_stage_id
=
tf
.
Variable
(
stage_id
,
trainable
=
False
,
dtype
=
tf
.
int64
,
aggregation
=
tf
.
VariableAggregation
.
ONLY_FIRST_REPLICA
,
shape
=
[])
self
.
_volatiles
.
reassign_trackable
(
optimizer
=
self
.
get_optimizer
(
stage_id
),
model
=
self
.
get_model
(
stage_id
,
old_model
=
None
))
# pytype: disable=wrong-arg-types # typed-keras
streamz_counters
.
progressive_policy_creation_counter
.
get_cell
(
).
increase_by
(
1
)
def
compute_stage_id
(
self
,
global_step
:
int
)
->
int
:
for
stage_id
in
range
(
self
.
num_stages
()):
global_step
-=
self
.
num_steps
(
stage_id
)
if
global_step
<
0
:
return
stage_id
logging
.
error
(
'Global step %d found no matching progressive stages. '
'Default to the last stage.'
,
global_step
)
return
self
.
num_stages
()
-
1
@
abc
.
abstractmethod
def
num_stages
(
self
)
->
int
:
"""Return the total number of progressive stages."""
pass
@
abc
.
abstractmethod
def
num_steps
(
self
,
stage_id
:
int
)
->
int
:
"""Return the total number of steps in this stage."""
pass
@
abc
.
abstractmethod
def
get_model
(
self
,
stage_id
:
int
,
old_model
:
tf
.
keras
.
Model
=
None
)
->
tf
.
keras
.
Model
:
# pytype: disable=annotation-type-mismatch # typed-keras
"""Return model for this stage. For initialization, `old_model` = None."""
pass
@
abc
.
abstractmethod
def
get_optimizer
(
self
,
stage_id
:
int
)
->
tf
.
keras
.
optimizers
.
Optimizer
:
"""Return optimizer for this stage."""
pass
@
abc
.
abstractmethod
def
get_train_dataset
(
self
,
stage_id
:
int
)
->
tf
.
data
.
Dataset
:
"""Return training Dataset for this stage."""
pass
@
abc
.
abstractmethod
def
get_eval_dataset
(
self
,
stage_id
:
int
)
->
tf
.
data
.
Dataset
:
"""Return evaluation Dataset for this stage."""
pass
@
property
def
cur_model
(
self
)
->
tf
.
keras
.
Model
:
return
self
.
_volatiles
.
model
@
property
def
cur_train_dataset
(
self
)
->
tf
.
data
.
Dataset
:
if
self
.
_cur_train_dataset
is
None
:
self
.
_cur_train_dataset
=
self
.
get_train_dataset
(
self
.
_stage_id
.
numpy
())
return
self
.
_cur_train_dataset
@
property
def
cur_eval_dataset
(
self
)
->
tf
.
data
.
Dataset
:
if
self
.
_cur_eval_dataset
is
None
:
self
.
_cur_eval_dataset
=
self
.
get_eval_dataset
(
self
.
_stage_id
.
numpy
())
return
self
.
_cur_eval_dataset
@
property
def
cur_optimizer
(
self
)
->
tf
.
keras
.
optimizers
.
Optimizer
:
return
self
.
_volatiles
.
optimizer
@
property
def
is_last_stage
(
self
)
->
bool
:
stage_id
=
self
.
_stage_id
.
numpy
()
return
stage_id
>=
self
.
num_stages
()
-
1
@
property
def
cur_checkpoint_items
(
self
)
->
Mapping
[
str
,
Any
]:
return
dict
(
stage_id
=
self
.
_stage_id
,
volatiles
=
self
.
_volatiles
)
def
is_stage_advancing
(
self
,
global_step
:
int
)
->
bool
:
old_stage_id
=
self
.
_stage_id
.
numpy
()
new_stage_id
=
self
.
compute_stage_id
(
global_step
)
return
old_stage_id
!=
new_stage_id
def
update_pt_stage
(
self
,
global_step
:
int
,
pass_old_model
=
True
)
->
None
:
"""Update progressive training internal status.
Call this after a training loop ends.
Args:
global_step: an integer scalar of the current global step.
pass_old_model: whether to pass the old_model to get_model() function.
This is set to False if the old_model is irrelevant (e.g, just a default
model from stage 0).
"""
old_stage_id
=
self
.
_stage_id
.
numpy
()
new_stage_id
=
self
.
compute_stage_id
(
global_step
)
logging
.
info
(
'Switching stage from %d to %d'
,
old_stage_id
,
new_stage_id
)
# Update stage id.
self
.
_stage_id
.
assign
(
new_stage_id
)
# Update dataset function.
self
.
_cur_train_dataset
=
None
self
.
_cur_eval_dataset
=
None
# Update optimizer and model.
new_optimizer
=
self
.
get_optimizer
(
new_stage_id
)
self
.
_volatiles
.
reassign_trackable
(
optimizer
=
new_optimizer
)
new_model
=
self
.
get_model
(
new_stage_id
,
old_model
=
self
.
cur_model
if
pass_old_model
else
None
)
self
.
_volatiles
.
reassign_trackable
(
model
=
new_model
)
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/fast_training/progressive/train.py
0 → 100644
View file @
a32ffa95
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TFM binary for the progressive trainer."""
from
absl
import
app
from
absl
import
flags
import
gin
from
official.common
import
distribute_utils
# pylint: disable=unused-import
from
official.common
import
registry_imports
# pylint: enable=unused-import
from
official.common
import
flags
as
tfm_flags
from
official.core
import
task_factory
from
official.core
import
train_utils
from
official.modeling
import
performance
from
official.modeling.fast_training.progressive
import
train_lib
FLAGS
=
flags
.
FLAGS
def
main
(
_
):
gin
.
parse_config_files_and_bindings
(
FLAGS
.
gin_file
,
FLAGS
.
gin_params
)
params
=
train_utils
.
parse_configuration
(
FLAGS
)
model_dir
=
FLAGS
.
model_dir
if
'train'
in
FLAGS
.
mode
:
# Pure eval modes do not output yaml files. Otherwise continuous eval job
# may race against the train job for writing the same file.
train_utils
.
serialize_config
(
params
,
model_dir
)
# Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
# can have significant impact on model speeds by utilizing float16 in case of
# GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
# dtype is float16
if
params
.
runtime
.
mixed_precision_dtype
:
performance
.
set_mixed_precision_policy
(
params
.
runtime
.
mixed_precision_dtype
)
distribution_strategy
=
distribute_utils
.
get_distribution_strategy
(
distribution_strategy
=
params
.
runtime
.
distribution_strategy
,
all_reduce_alg
=
params
.
runtime
.
all_reduce_alg
,
num_gpus
=
params
.
runtime
.
num_gpus
,
tpu_address
=
params
.
runtime
.
tpu
,
**
params
.
runtime
.
model_parallelism
())
with
distribution_strategy
.
scope
():
task
=
task_factory
.
get_task
(
params
.
task
,
logging_dir
=
model_dir
)
train_lib
.
run_experiment
(
distribution_strategy
=
distribution_strategy
,
task
=
task
,
mode
=
FLAGS
.
mode
,
params
=
params
,
model_dir
=
model_dir
)
train_utils
.
save_gin_config
(
FLAGS
.
mode
,
model_dir
)
if
__name__
==
'__main__'
:
tfm_flags
.
define_flags
()
app
.
run
(
main
)
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/fast_training/progressive/train_lib.py
0 → 100644
View file @
a32ffa95
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TFM progressive training driver library.
Compared to the common training driver, the only difference is that we use
prog_trainer_lib.ProgressiveTrainer instead of the base trainer.
"""
# pytype: disable=attribute-error
import
os
from
typing
import
Any
,
Mapping
,
Tuple
# Import libraries
from
absl
import
logging
import
orbit
import
tensorflow
as
tf
from
official.core
import
base_task
from
official.core
import
config_definitions
from
official.core
import
train_lib
as
base_train_lib
from
official.modeling.fast_training.progressive
import
trainer
as
prog_trainer_lib
def
run_experiment
(
distribution_strategy
:
tf
.
distribute
.
Strategy
,
task
:
base_task
.
Task
,
mode
:
str
,
params
:
config_definitions
.
ExperimentConfig
,
model_dir
:
str
,
run_post_eval
:
bool
=
False
,
save_summary
:
bool
=
True
)
\
->
Tuple
[
tf
.
keras
.
Model
,
Mapping
[
str
,
Any
]]:
"""Runs train/eval configured by the experiment params.
Args:
distribution_strategy: A distribution distribution_strategy.
task: A Task instance.
mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
or 'continuous_eval'.
params: ExperimentConfig instance.
model_dir: A 'str', a path to store model checkpoints and summaries.
run_post_eval: Whether to run post eval once after training, metrics logs
are returned.
save_summary: Whether to save train and validation summary.
Returns:
A 2-tuple of (model, eval_logs).
model: `tf.keras.Model` instance.
eval_logs: returns eval metrics logs when run_post_eval is set to True,
otherwise, returns {}.
"""
with
distribution_strategy
.
scope
():
logging
.
info
(
'Running progressive trainer.'
)
trainer
=
prog_trainer_lib
.
ProgressiveTrainer
(
params
,
task
,
ckpt_dir
=
model_dir
,
train
=
'train'
in
mode
,
evaluate
=
(
'eval'
in
mode
)
or
run_post_eval
,
checkpoint_exporter
=
base_train_lib
.
maybe_create_best_ckpt_exporter
(
params
,
model_dir
))
if
trainer
.
checkpoint
:
checkpoint_manager
=
tf
.
train
.
CheckpointManager
(
trainer
.
checkpoint
,
directory
=
model_dir
,
max_to_keep
=
params
.
trainer
.
max_to_keep
,
step_counter
=
trainer
.
global_step
,
checkpoint_interval
=
params
.
trainer
.
checkpoint_interval
,
init_fn
=
trainer
.
initialize
)
else
:
checkpoint_manager
=
None
controller
=
orbit
.
Controller
(
strategy
=
distribution_strategy
,
trainer
=
trainer
if
'train'
in
mode
else
None
,
evaluator
=
trainer
,
global_step
=
trainer
.
global_step
,
steps_per_loop
=
params
.
trainer
.
steps_per_loop
,
checkpoint_manager
=
checkpoint_manager
,
summary_dir
=
os
.
path
.
join
(
model_dir
,
'train'
)
if
(
save_summary
)
else
None
,
eval_summary_dir
=
os
.
path
.
join
(
model_dir
,
'validation'
)
if
(
save_summary
)
else
None
,
summary_interval
=
params
.
trainer
.
summary_interval
if
(
save_summary
)
else
None
)
logging
.
info
(
'Starts to execute mode: %s'
,
mode
)
with
distribution_strategy
.
scope
():
if
mode
==
'train'
:
controller
.
train
(
steps
=
params
.
trainer
.
train_steps
)
elif
mode
==
'train_and_eval'
:
controller
.
train_and_evaluate
(
train_steps
=
params
.
trainer
.
train_steps
,
eval_steps
=
params
.
trainer
.
validation_steps
,
eval_interval
=
params
.
trainer
.
validation_interval
)
elif
mode
==
'eval'
:
controller
.
evaluate
(
steps
=
params
.
trainer
.
validation_steps
)
elif
mode
==
'continuous_eval'
:
def
timeout_fn
():
if
trainer
.
global_step
.
numpy
()
>=
params
.
trainer
.
train_steps
:
return
True
return
False
controller
.
evaluate_continuously
(
steps
=
params
.
trainer
.
validation_steps
,
timeout
=
params
.
trainer
.
continuous_eval_timeout
,
timeout_fn
=
timeout_fn
)
else
:
raise
NotImplementedError
(
'The mode is not implemented: %s'
%
mode
)
if
run_post_eval
:
with
distribution_strategy
.
scope
():
return
trainer
.
model
,
trainer
.
evaluate
(
tf
.
convert_to_tensor
(
params
.
trainer
.
validation_steps
))
else
:
return
trainer
.
model
,
{}
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/fast_training/progressive/train_lib_test.py
0 → 100644
View file @
a32ffa95
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the progressive train_lib."""
import
os
from
absl
import
flags
from
absl.testing
import
parameterized
import
dataclasses
import
orbit
import
tensorflow
as
tf
from
tensorflow.python.distribute
import
combinations
from
tensorflow.python.distribute
import
strategy_combinations
from
official.common
import
flags
as
tfm_flags
# pylint: disable=unused-import
from
official.common
import
registry_imports
# pylint: enable=unused-import
from
official.core
import
config_definitions
as
cfg
from
official.core
import
task_factory
from
official.modeling
import
optimization
from
official.modeling.hyperparams
import
params_dict
from
official.modeling.fast_training.progressive
import
policies
from
official.modeling.fast_training.progressive
import
train_lib
from
official.modeling.fast_training.progressive
import
trainer
as
prog_trainer_lib
from
official.utils.testing
import
mock_task
FLAGS
=
flags
.
FLAGS
tfm_flags
.
define_flags
()
@
dataclasses
.
dataclass
class
ProgTaskConfig
(
cfg
.
TaskConfig
):
pass
@
task_factory
.
register_task_cls
(
ProgTaskConfig
)
class
ProgMockTask
(
policies
.
ProgressivePolicy
,
mock_task
.
MockTask
):
"""Progressive task for testing."""
def
__init__
(
self
,
params
:
cfg
.
TaskConfig
,
logging_dir
:
str
=
None
):
mock_task
.
MockTask
.
__init__
(
self
,
params
=
params
,
logging_dir
=
logging_dir
)
policies
.
ProgressivePolicy
.
__init__
(
self
)
def
num_stages
(
self
):
return
2
def
num_steps
(
self
,
stage_id
):
return
2
if
stage_id
==
0
else
4
def
get_model
(
self
,
stage_id
,
old_model
=
None
):
del
stage_id
,
old_model
return
self
.
build_model
()
def
get_optimizer
(
self
,
stage_id
):
"""Build optimizer for each stage."""
params
=
optimization
.
OptimizationConfig
({
'optimizer'
:
{
'type'
:
'adamw'
,
},
'learning_rate'
:
{
'type'
:
'polynomial'
,
'polynomial'
:
{
'initial_learning_rate'
:
0.01
,
'end_learning_rate'
:
0.0
,
'power'
:
1.0
,
'decay_steps'
:
10
,
},
},
'warmup'
:
{
'polynomial'
:
{
'power'
:
1
,
'warmup_steps'
:
2
,
},
'type'
:
'polynomial'
,
}
})
opt_factory
=
optimization
.
OptimizerFactory
(
params
)
optimizer
=
opt_factory
.
build_optimizer
(
opt_factory
.
build_learning_rate
())
return
optimizer
def
get_train_dataset
(
self
,
stage_id
):
del
stage_id
strategy
=
tf
.
distribute
.
get_strategy
()
return
orbit
.
utils
.
make_distributed_dataset
(
strategy
,
self
.
build_inputs
,
None
)
def
get_eval_dataset
(
self
,
stage_id
):
del
stage_id
strategy
=
tf
.
distribute
.
get_strategy
()
return
orbit
.
utils
.
make_distributed_dataset
(
strategy
,
self
.
build_inputs
,
None
)
class
TrainTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
def
setUp
(
self
):
super
(
TrainTest
,
self
).
setUp
()
self
.
_test_config
=
{
'trainer'
:
{
'checkpoint_interval'
:
10
,
'steps_per_loop'
:
10
,
'summary_interval'
:
10
,
'train_steps'
:
10
,
'validation_steps'
:
5
,
'validation_interval'
:
10
,
'continuous_eval_timeout'
:
1
,
'optimizer_config'
:
{
'optimizer'
:
{
'type'
:
'sgd'
,
},
'learning_rate'
:
{
'type'
:
'constant'
}
}
},
}
@
combinations
.
generate
(
combinations
.
combine
(
distribution_strategy
=
[
strategy_combinations
.
default_strategy
,
strategy_combinations
.
cloud_tpu_strategy
,
strategy_combinations
.
one_device_strategy_gpu
,
],
flag_mode
=
[
'train'
,
'eval'
,
'train_and_eval'
],
run_post_eval
=
[
True
,
False
]))
def
test_end_to_end
(
self
,
distribution_strategy
,
flag_mode
,
run_post_eval
):
model_dir
=
self
.
get_temp_dir
()
experiment_config
=
cfg
.
ExperimentConfig
(
trainer
=
prog_trainer_lib
.
ProgressiveTrainerConfig
(),
task
=
ProgTaskConfig
())
experiment_config
=
params_dict
.
override_params_dict
(
experiment_config
,
self
.
_test_config
,
is_strict
=
False
)
with
distribution_strategy
.
scope
():
task
=
task_factory
.
get_task
(
experiment_config
.
task
,
logging_dir
=
model_dir
)
_
,
logs
=
train_lib
.
run_experiment
(
distribution_strategy
=
distribution_strategy
,
task
=
task
,
mode
=
flag_mode
,
params
=
experiment_config
,
model_dir
=
model_dir
,
run_post_eval
=
run_post_eval
)
if
run_post_eval
:
self
.
assertNotEmpty
(
logs
)
else
:
self
.
assertEmpty
(
logs
)
if
flag_mode
==
'eval'
:
return
self
.
assertNotEmpty
(
tf
.
io
.
gfile
.
glob
(
os
.
path
.
join
(
model_dir
,
'checkpoint'
)))
# Tests continuous evaluation.
_
,
logs
=
train_lib
.
run_experiment
(
distribution_strategy
=
distribution_strategy
,
task
=
task
,
mode
=
'continuous_eval'
,
params
=
experiment_config
,
model_dir
=
model_dir
,
run_post_eval
=
run_post_eval
)
print
(
logs
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/fast_training/progressive/trainer.py
0 → 100644
View file @
a32ffa95
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Progressive Trainer implementation.
The trainer implements the Orbit `StandardTrainable` and
`StandardEvaluable` interfaces. Trainers inside this project should be
interchangable and independent on model architectures and tasks.
"""
import
dataclasses
import
os
from
typing
import
Any
,
Optional
# Import libraries
from
absl
import
logging
import
gin
import
orbit
import
tensorflow
as
tf
from
official.core
import
base_task
from
official.core
import
base_trainer
as
trainer_lib
from
official.core
import
config_definitions
from
official.modeling.fast_training.progressive
import
policies
from
official.modeling.fast_training.progressive
import
utils
ExperimentConfig
=
config_definitions
.
ExperimentConfig
@
dataclasses
.
dataclass
class
ProgressiveTrainerConfig
(
config_definitions
.
TrainerConfig
):
"""Configuration for progressive trainer.
Attributes:
progressive: A task-specific config. Users can subclass ProgressiveConfig
and define any task-specific settings in their subclass.
export_checkpoint: A bool. Whether to export checkpoints in non-progressive
manner (without the volatiles wrapper) such that your down-stream tasks
can load checkpoints from a progressive trainer as if it is a regular
checkpoint.
export_checkpoint_interval: A bool. The number of steps between exporting
checkpoints. If None (by default), will use the same value as
TrainerConfig.checkpoint_interval.
export_max_to_keep: The maximum number of exported checkpoints to keep.
If None (by default), will use the same value as
TrainerConfig.max_to_keep.
export_only_final_stage_ckpt: A bool. Whether to just export checkpoints
during the final progressive training stage. In other words, whether to
not export small, partial models. In many cases, it is not meaningful to
finetune a small, partial model in down-stream tasks.
"""
progressive
:
Optional
[
policies
.
ProgressiveConfig
]
=
None
export_checkpoint
:
bool
=
True
export_checkpoint_interval
:
Optional
[
int
]
=
None
export_max_to_keep
:
Optional
[
int
]
=
None
export_only_final_stage_ckpt
:
bool
=
True
@
gin
.
configurable
class
ProgressiveTrainer
(
trainer_lib
.
Trainer
):
"""Implements the progressive trainer shared for TensorFlow models."""
def
__init__
(
self
,
config
:
ExperimentConfig
,
prog_task
:
base_task
.
Task
,
# also implemented ProgressivePolicy.
ckpt_dir
:
str
=
''
,
train
:
bool
=
True
,
evaluate
:
bool
=
True
,
checkpoint_exporter
:
Any
=
None
):
"""Initialize common trainer for TensorFlow models.
Args:
config: An `ExperimentConfig` instance specifying experiment config.
prog_task: An instance both implemented policies.ProgressivePolicy and
base_task.Task.
ckpt_dir: Checkpoint directory.
train: bool, whether or not this trainer will be used for training.
default to True.
evaluate: bool, whether or not this trainer will be used for evaluation.
default to True.
checkpoint_exporter: an object that has the `maybe_export_checkpoint`
interface.
"""
# Gets the current distribution strategy. If not inside any strategy scope,
# it gets a single-replica no-op strategy.
self
.
_strategy
=
tf
.
distribute
.
get_strategy
()
self
.
_config
=
config
self
.
_runtime_options
=
trainer_lib
.
get_runtime_options
(
config
)
self
.
_task
=
prog_task
# Directory for non-progressive checkpoint
self
.
_export_ckpt_dir
=
os
.
path
.
join
(
ckpt_dir
,
'exported_ckpts'
)
tf
.
io
.
gfile
.
makedirs
(
self
.
_export_ckpt_dir
)
self
.
_export_ckpt_manager
=
None
# Receive other checkpoint export, e.g, best checkpoint exporter.
# TODO(lehou): unify the checkpoint exporting logic, although the default
# setting does not use checkpoint_exporter.
self
.
_checkpoint_exporter
=
checkpoint_exporter
self
.
_global_step
=
orbit
.
utils
.
create_global_step
()
self
.
_checkpoint
=
utils
.
CheckpointWithHooks
(
before_load_hook
=
self
.
_update_pt_stage_from_ckpt
,
global_step
=
self
.
global_step
,
**
self
.
_task
.
cur_checkpoint_items
)
self
.
_train_loss
=
tf
.
keras
.
metrics
.
Mean
(
'training_loss'
,
dtype
=
tf
.
float32
)
self
.
_validation_loss
=
tf
.
keras
.
metrics
.
Mean
(
'validation_loss'
,
dtype
=
tf
.
float32
)
self
.
_train_metrics
=
self
.
task
.
build_metrics
(
training
=
True
)
+
self
.
model
.
metrics
self
.
_validation_metrics
=
self
.
task
.
build_metrics
(
training
=
False
)
+
self
.
model
.
metrics
if
train
:
orbit
.
StandardTrainer
.
__init__
(
self
,
None
,
# Manage train_dataset by ourselves, not by StandardTrainer.
options
=
orbit
.
StandardTrainerOptions
(
use_tf_while_loop
=
config
.
trainer
.
train_tf_while_loop
,
use_tf_function
=
config
.
trainer
.
train_tf_function
))
if
evaluate
:
orbit
.
StandardEvaluator
.
__init__
(
self
,
None
,
# Manage train_dataset by ourselves, not by StandardEvaluator.
options
=
orbit
.
StandardEvaluatorOptions
(
use_tf_function
=
config
.
trainer
.
eval_tf_function
))
@
property
def
model
(
self
):
return
self
.
_task
.
cur_model
@
property
def
optimizer
(
self
):
return
self
.
_task
.
cur_optimizer
# override
@
property
def
train_dataset
(
self
):
"""Overriding StandardTrainer.train_dataset."""
return
self
.
_task
.
cur_train_dataset
# override
@
train_dataset
.
setter
def
train_dataset
(
self
,
_
):
raise
SyntaxError
(
'Please do not set train_dataset. Progressive training '
'relies on progressive policy to manager train dataset.'
)
# override
@
property
def
eval_dataset
(
self
):
"""Overriding StandardEvaluator.eval_dataset."""
return
self
.
_task
.
cur_eval_dataset
# override
@
eval_dataset
.
setter
def
eval_dataset
(
self
,
_
):
raise
SyntaxError
(
'Please do not set eval_dataset. Progressive training '
'relies on progressive policy to manager eval dataset.'
)
def
train_loop_end
(
self
):
"""See base class."""
logs
=
{}
for
metric
in
self
.
train_metrics
+
[
self
.
train_loss
]:
logs
[
metric
.
name
]
=
metric
.
result
()
metric
.
reset_states
()
if
callable
(
self
.
optimizer
.
learning_rate
):
logs
[
'learning_rate'
]
=
self
.
optimizer
.
learning_rate
(
self
.
optimizer
.
iterations
)
else
:
logs
[
'learning_rate'
]
=
self
.
optimizer
.
learning_rate
self
.
_maybe_export_non_progressive_checkpoint
(
self
.
_export_ckpt_dir
)
if
self
.
_task
.
is_stage_advancing
(
self
.
global_step
.
numpy
()):
old_train_dataset
=
self
.
train_dataset
# Update progressive properties
self
.
_task
.
update_pt_stage
(
self
.
global_step
.
numpy
())
# Setting `self._train_loop_fn` and `self._eval_loop_fn` to None will
# rebuild the train and eval functions with the updated model.
self
.
_train_loop_fn
=
None
self
.
_eval_loop_fn
=
None
if
self
.
train_dataset
!=
old_train_dataset
:
# Setting `self._train_iter` to None will rebuild the dataset iterator.
self
.
_train_iter
=
None
# Setting `self._export_ckpt_manager` to None will rebuild the checkpoint
# for exporting.
self
.
_export_ckpt_manager
=
None
return
logs
def
_update_pt_stage_from_ckpt
(
self
,
ckpt_file
):
"""Update stage properties based on the global_step variable in a ckpt file.
Before loading variables from a checkpoint file, we need to go to the
correct stage and build corresponding model and optimizer, to make sure that
we retore variables of the right model and optimizer.
Args:
ckpt_file: Checkpoint file that will be restored/read from.
"""
if
not
ckpt_file
:
return
ckpt
=
tf
.
train
.
Checkpoint
(
global_step
=
self
.
global_step
)
ckpt
.
read
(
ckpt_file
).
expect_partial
().
assert_existing_objects_matched
()
if
self
.
_task
.
is_stage_advancing
(
self
.
global_step
.
numpy
()):
old_train_dataset
=
self
.
train_dataset
# Update progressive properties
self
.
_task
.
update_pt_stage
(
self
.
global_step
.
numpy
(),
pass_old_model
=
False
)
# Setting `self._train_loop_fn` and `self._eval_loop_fn` to None will
# rebuild the train and eval functions with the updated model.
self
.
_train_loop_fn
=
None
self
.
_eval_loop_fn
=
None
if
self
.
train_dataset
!=
old_train_dataset
:
# Setting `self._train_iter` to None will rebuild the dataset iterator.
self
.
_train_iter
=
None
# Setting `self._export_ckpt_manager` to None will rebuild the checkpoint
# for exporting.
self
.
_export_ckpt_manager
=
None
def
_maybe_export_non_progressive_checkpoint
(
self
,
export_ckpt_dir
):
"""Export checkpoints in non-progressive format.
This basically removes the wrapping of self._task.cur_checkpoint_items
-- just save the model, optimizer, etc., directly.
The purpose is to let your down-stream tasks to use these checkpoints.
Args:
export_ckpt_dir: A str. folder of exported checkpoints.
"""
if
not
self
.
config
.
trainer
.
export_checkpoint
:
logging
.
info
(
'Not exporting checkpoints.'
)
return
if
not
self
.
_task
.
is_last_stage
and
(
self
.
config
.
trainer
.
export_only_final_stage_ckpt
):
logging
.
info
(
'Not exporting checkpoints until the last stage.'
)
return
if
self
.
_export_ckpt_manager
is
None
:
# Create a checkpoint object just now, to make sure we use
# progressive_policy.cur_model and progressive_policy.cur_optimizer of the
# current stage.
if
hasattr
(
self
.
model
,
'checkpoint_items'
):
checkpoint_items
=
self
.
model
.
checkpoint_items
else
:
checkpoint_items
=
{}
checkpoint
=
tf
.
train
.
Checkpoint
(
global_step
=
self
.
global_step
,
model
=
self
.
model
,
optimizer
=
self
.
optimizer
,
**
checkpoint_items
)
max_to_keep
=
self
.
config
.
trainer
.
export_max_to_keep
or
(
self
.
config
.
trainer
.
max_to_keep
)
checkpoint_interval
=
self
.
config
.
trainer
.
export_checkpoint_interval
or
(
self
.
config
.
trainer
.
checkpoint_interval
)
self
.
_export_ckpt_manager
=
tf
.
train
.
CheckpointManager
(
checkpoint
,
directory
=
export_ckpt_dir
,
checkpoint_name
=
'ckpt'
,
step_counter
=
self
.
global_step
,
max_to_keep
=
max_to_keep
,
checkpoint_interval
=
checkpoint_interval
,
)
# Make sure we export the last checkpoint.
last_checkpoint
=
(
self
.
global_step
.
numpy
()
==
self
.
_config
.
trainer
.
train_steps
)
checkpoint_path
=
self
.
_export_ckpt_manager
.
save
(
checkpoint_number
=
self
.
global_step
.
numpy
(),
check_interval
=
not
last_checkpoint
)
if
checkpoint_path
:
logging
.
info
(
'Checkpoints exported: %s.'
,
checkpoint_path
)
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/fast_training/progressive/trainer_test.py
0 → 100644
View file @
a32ffa95
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the progressive trainer."""
# pylint: disable=g-direct-tensorflow-import
import
os
from
absl.testing
import
parameterized
import
orbit
import
tensorflow
as
tf
from
tensorflow.python.distribute
import
combinations
from
tensorflow.python.distribute
import
strategy_combinations
from
official.core
import
config_definitions
as
cfg
from
official.modeling
import
optimization
from
official.modeling.fast_training.progressive
import
policies
from
official.modeling.fast_training.progressive
import
trainer
as
trainer_lib
from
official.nlp.configs
import
bert
from
official.utils.testing
import
mock_task
def
all_strategy_combinations
():
return
combinations
.
combine
(
distribution
=
[
strategy_combinations
.
default_strategy
,
strategy_combinations
.
cloud_tpu_strategy
,
strategy_combinations
.
one_device_strategy_gpu
,
],)
def
get_exp_config
():
return
cfg
.
ExperimentConfig
(
task
=
cfg
.
TaskConfig
(
model
=
bert
.
PretrainerConfig
()),
trainer
=
trainer_lib
.
ProgressiveTrainerConfig
(
export_checkpoint
=
True
,
export_checkpoint_interval
=
1
,
export_only_final_stage_ckpt
=
False
))
class
TestPolicy
(
policies
.
ProgressivePolicy
,
mock_task
.
MockTask
):
"""Just for testing purposes."""
def
__init__
(
self
,
strategy
,
task_config
,
change_train_dataset
=
True
):
self
.
_strategy
=
strategy
self
.
_change_train_dataset
=
change_train_dataset
self
.
_my_train_dataset
=
None
mock_task
.
MockTask
.
__init__
(
self
,
params
=
task_config
,
logging_dir
=
None
)
policies
.
ProgressivePolicy
.
__init__
(
self
)
def
num_stages
(
self
)
->
int
:
return
2
def
num_steps
(
self
,
stage_id
:
int
)
->
int
:
return
2
if
stage_id
==
0
else
4
def
get_model
(
self
,
stage_id
:
int
,
old_model
:
tf
.
keras
.
Model
)
->
tf
.
keras
.
Model
:
del
stage_id
,
old_model
return
self
.
build_model
()
def
get_optimizer
(
self
,
stage_id
:
int
)
->
tf
.
keras
.
optimizers
.
Optimizer
:
optimizer_type
=
'sgd'
if
stage_id
==
0
else
'adamw'
optimizer_config
=
cfg
.
OptimizationConfig
({
'optimizer'
:
{
'type'
:
optimizer_type
},
'learning_rate'
:
{
'type'
:
'constant'
}})
opt_factory
=
optimization
.
OptimizerFactory
(
optimizer_config
)
return
opt_factory
.
build_optimizer
(
opt_factory
.
build_learning_rate
())
def
get_train_dataset
(
self
,
stage_id
:
int
)
->
tf
.
data
.
Dataset
:
if
not
self
.
_change_train_dataset
and
self
.
_my_train_dataset
:
return
self
.
_my_train_dataset
if
self
.
_strategy
:
self
.
_my_train_dataset
=
orbit
.
utils
.
make_distributed_dataset
(
self
.
_strategy
,
self
.
_build_inputs
,
stage_id
)
else
:
self
.
_my_train_dataset
=
self
.
_build_inputs
(
stage_id
)
return
self
.
_my_train_dataset
def
get_eval_dataset
(
self
,
stage_id
:
int
)
->
tf
.
data
.
Dataset
:
if
self
.
_strategy
:
return
orbit
.
utils
.
make_distributed_dataset
(
self
.
_strategy
,
self
.
_build_inputs
,
stage_id
)
return
self
.
_build_inputs
(
stage_id
)
def
_build_inputs
(
self
,
stage_id
):
def
dummy_data
(
_
):
batch_size
=
2
if
stage_id
==
0
else
1
x
=
tf
.
zeros
(
shape
=
(
batch_size
,
2
),
dtype
=
tf
.
float32
)
label
=
tf
.
zeros
(
shape
=
(
batch_size
,
1
),
dtype
=
tf
.
float32
)
return
x
,
label
dataset
=
tf
.
data
.
Dataset
.
range
(
1
)
dataset
=
dataset
.
repeat
()
return
dataset
.
map
(
dummy_data
,
num_parallel_calls
=
tf
.
data
.
experimental
.
AUTOTUNE
)
class
TrainerTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
def
setUp
(
self
):
super
(
TrainerTest
,
self
).
setUp
()
self
.
_config
=
get_exp_config
()
def
create_test_trainer
(
self
,
distribution
,
model_dir
,
change_train_dataset
):
trainer
=
trainer_lib
.
ProgressiveTrainer
(
self
.
_config
,
prog_task
=
TestPolicy
(
distribution
,
self
.
_config
.
task
,
change_train_dataset
),
ckpt_dir
=
model_dir
)
return
trainer
@
combinations
.
generate
(
all_strategy_combinations
())
def
test_checkpointing
(
self
,
distribution
):
model_dir
=
self
.
get_temp_dir
()
ckpt_file
=
os
.
path
.
join
(
model_dir
,
'ckpt'
)
with
distribution
.
scope
():
trainer
=
self
.
create_test_trainer
(
distribution
,
model_dir
,
True
)
self
.
assertFalse
(
trainer
.
_task
.
is_last_stage
)
trainer
.
train
(
tf
.
convert_to_tensor
(
4
,
dtype
=
tf
.
int32
))
self
.
assertTrue
(
trainer
.
_task
.
is_last_stage
)
trainer
.
checkpoint
.
save
(
ckpt_file
)
trainer
=
self
.
create_test_trainer
(
distribution
,
model_dir
,
True
)
self
.
assertFalse
(
trainer
.
_task
.
is_last_stage
)
trainer
.
checkpoint
.
restore
(
ckpt_file
+
'-1'
)
self
.
assertTrue
(
trainer
.
_task
.
is_last_stage
)
@
combinations
.
generate
(
all_strategy_combinations
())
def
test_train_dataset
(
self
,
distribution
):
model_dir
=
self
.
get_temp_dir
()
with
distribution
.
scope
():
trainer
=
self
.
create_test_trainer
(
distribution
,
model_dir
,
True
)
# Using dataset of stage == 0
train_iter
=
tf
.
nest
.
map_structure
(
iter
,
trainer
.
train_dataset
)
train_data
=
train_iter
.
next
()[
0
]
if
distribution
.
num_replicas_in_sync
>
1
:
train_data
=
train_data
.
values
[
0
]
self
.
assertEqual
(
train_data
.
shape
[
0
],
2
)
trainer
.
train
(
tf
.
convert_to_tensor
(
4
,
dtype
=
tf
.
int32
))
# Using dataset of stage == 1
train_iter
=
tf
.
nest
.
map_structure
(
iter
,
trainer
.
train_dataset
)
train_data
=
train_iter
.
next
()[
0
]
if
distribution
.
num_replicas_in_sync
>
1
:
train_data
=
train_data
.
values
[
0
]
self
.
assertEqual
(
train_data
.
shape
[
0
],
1
)
with
self
.
assertRaises
(
SyntaxError
):
trainer
.
train_dataset
=
None
@
combinations
.
generate
(
all_strategy_combinations
())
def
test_train_dataset_no_switch
(
self
,
distribution
):
model_dir
=
self
.
get_temp_dir
()
with
distribution
.
scope
():
trainer
=
self
.
create_test_trainer
(
distribution
,
model_dir
,
False
)
trainer
.
train
(
tf
.
convert_to_tensor
(
2
,
dtype
=
tf
.
int32
))
# _train_iter is not reset since the dataset is not changed.
self
.
assertIsNotNone
(
trainer
.
_train_iter
)
with
distribution
.
scope
():
trainer
=
self
.
create_test_trainer
(
distribution
,
model_dir
,
True
)
trainer
.
train
(
tf
.
convert_to_tensor
(
2
,
dtype
=
tf
.
int32
))
# _train_iter is reset since the dataset changed.
self
.
assertIsNone
(
trainer
.
_train_iter
)
class
TrainerWithMaskedLMTaskTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
def
setUp
(
self
):
super
(
TrainerWithMaskedLMTaskTest
,
self
).
setUp
()
self
.
_config
=
get_exp_config
()
def
create_test_trainer
(
self
,
distribution
):
trainer
=
trainer_lib
.
ProgressiveTrainer
(
self
.
_config
,
prog_task
=
TestPolicy
(
distribution
,
self
.
_config
.
task
),
ckpt_dir
=
self
.
get_temp_dir
())
return
trainer
@
combinations
.
generate
(
all_strategy_combinations
())
def
test_trainer_train
(
self
,
distribution
):
with
distribution
.
scope
():
trainer
=
self
.
create_test_trainer
(
distribution
)
logs
=
trainer
.
train
(
tf
.
convert_to_tensor
(
5
,
dtype
=
tf
.
int32
))
self
.
assertIn
(
'training_loss'
,
logs
)
self
.
assertIn
(
'learning_rate'
,
logs
)
@
combinations
.
generate
(
all_strategy_combinations
())
def
test_trainer_validate
(
self
,
distribution
):
with
distribution
.
scope
():
trainer
=
self
.
create_test_trainer
(
distribution
)
logs
=
trainer
.
evaluate
(
tf
.
convert_to_tensor
(
5
,
dtype
=
tf
.
int32
))
self
.
assertIn
(
'validation_loss'
,
logs
)
self
.
assertEqual
(
logs
[
'counter'
],
5.
*
distribution
.
num_replicas_in_sync
)
@
combinations
.
generate
(
combinations
.
combine
(
mixed_precision_dtype
=
[
'float32'
,
'bfloat16'
,
'float16'
],
loss_scale
=
[
None
,
'dynamic'
,
128
,
256
],
))
def
test_configure_optimizer
(
self
,
mixed_precision_dtype
,
loss_scale
):
config
=
cfg
.
ExperimentConfig
(
task
=
cfg
.
TaskConfig
(
model
=
bert
.
PretrainerConfig
()),
runtime
=
cfg
.
RuntimeConfig
(
mixed_precision_dtype
=
mixed_precision_dtype
,
loss_scale
=
loss_scale
),
trainer
=
trainer_lib
.
ProgressiveTrainerConfig
(
export_checkpoint
=
True
,
export_checkpoint_interval
=
1
,
export_only_final_stage_ckpt
=
False
))
task
=
TestPolicy
(
None
,
config
.
task
)
trainer
=
trainer_lib
.
ProgressiveTrainer
(
config
,
task
,
self
.
get_temp_dir
())
if
mixed_precision_dtype
!=
'float16'
:
self
.
assertIsInstance
(
trainer
.
optimizer
,
tf
.
keras
.
optimizers
.
SGD
)
elif
mixed_precision_dtype
==
'float16'
and
loss_scale
is
None
:
self
.
assertIsInstance
(
trainer
.
optimizer
,
tf
.
keras
.
optimizers
.
SGD
)
metrics
=
trainer
.
train
(
tf
.
convert_to_tensor
(
5
,
dtype
=
tf
.
int32
))
self
.
assertIn
(
'training_loss'
,
metrics
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/fast_training/progressive/utils.py
0 → 100644
View file @
a32ffa95
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Util classes and functions."""
from
absl
import
logging
import
tensorflow
as
tf
# pylint: disable=g-direct-tensorflow-import
from
tensorflow.python.training.tracking
import
tracking
class
VolatileTrackable
(
tracking
.
AutoTrackable
):
"""A util class to keep Trackables that might change instances."""
def
__init__
(
self
,
**
kwargs
):
for
k
,
v
in
kwargs
.
items
():
setattr
(
self
,
k
,
v
)
def
reassign_trackable
(
self
,
**
kwargs
):
for
k
,
v
in
kwargs
.
items
():
delattr
(
self
,
k
)
# untrack this object
setattr
(
self
,
k
,
v
)
# track the new object
class
CheckpointWithHooks
(
tf
.
train
.
Checkpoint
):
"""Same as tf.train.Checkpoint but supports hooks.
In progressive training, use this class instead of tf.train.Checkpoint.
Since the network architecture changes during progressive training, we need to
prepare something (like switch to the correct architecture) before loading the
checkpoint. This class supports a hook that will be executed before checkpoint
loading.
"""
def
__init__
(
self
,
before_load_hook
,
**
kwargs
):
self
.
_before_load_hook
=
before_load_hook
super
(
CheckpointWithHooks
,
self
).
__init__
(
**
kwargs
)
# override
def
read
(
self
,
save_path
,
options
=
None
):
self
.
_before_load_hook
(
save_path
)
logging
.
info
(
'Ran before_load_hook.'
)
super
(
CheckpointWithHooks
,
self
).
read
(
save_path
=
save_path
,
options
=
options
)
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/grad_utils.py
0 → 100644
View file @
a32ffa95
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Some gradient util functions to help users writing custom training loop."""
from
absl
import
logging
import
tensorflow
as
tf
def
_filter_grads
(
grads_and_vars
):
"""Filter out iterable with grad equal to None."""
grads_and_vars
=
tuple
(
grads_and_vars
)
if
not
grads_and_vars
:
return
grads_and_vars
filtered
=
[]
vars_with_empty_grads
=
[]
for
grad
,
var
in
grads_and_vars
:
if
grad
is
None
:
vars_with_empty_grads
.
append
(
var
)
else
:
filtered
.
append
((
grad
,
var
))
filtered
=
tuple
(
filtered
)
if
not
filtered
:
raise
ValueError
(
"No gradients provided for any variable: %s."
%
([
v
.
name
for
_
,
v
in
grads_and_vars
],))
if
vars_with_empty_grads
:
logging
.
warning
(
(
"Gradients do not exist for variables %s when minimizing the loss."
),
([
v
.
name
for
v
in
vars_with_empty_grads
]))
return
filtered
def
_filter_and_allreduce_gradients
(
grads_and_vars
,
allreduce_precision
=
"float32"
,
bytes_per_pack
=
0
):
"""Filter None grads and then allreduce gradients in specified precision.
This utils function is used when users intent to explicitly allreduce
gradients and customize gradients operations before and after allreduce.
The allreduced gradients are then passed to optimizer.apply_gradients(
experimental_aggregate_gradients=False).
Args:
grads_and_vars: gradients and variables pairs.
allreduce_precision: Whether to allreduce gradients in float32 or float16.
bytes_per_pack: A non-negative integer. Breaks collective operations into
packs of certain size. If it's zero, all gradients are in one pack.
Returns:
pairs of allreduced non-None gradients and variables.
"""
filtered_grads_and_vars
=
_filter_grads
(
grads_and_vars
)
(
grads
,
variables
)
=
zip
(
*
filtered_grads_and_vars
)
if
allreduce_precision
==
"float16"
:
grads
=
[
tf
.
cast
(
grad
,
"float16"
)
for
grad
in
grads
]
hints
=
tf
.
distribute
.
experimental
.
CommunicationOptions
(
bytes_per_pack
=
bytes_per_pack
)
allreduced_grads
=
tf
.
distribute
.
get_strategy
(
# pylint: disable=protected-access
).
extended
.
_replica_ctx_all_reduce
(
tf
.
distribute
.
ReduceOp
.
SUM
,
grads
,
hints
)
if
allreduce_precision
==
"float16"
:
allreduced_grads
=
[
tf
.
cast
(
grad
,
"float32"
)
for
grad
in
allreduced_grads
]
return
allreduced_grads
,
variables
def
_run_callbacks
(
callbacks
,
grads_and_vars
):
for
callback
in
callbacks
:
grads_and_vars
=
callback
(
grads_and_vars
)
return
grads_and_vars
def
minimize_using_explicit_allreduce
(
tape
,
optimizer
,
loss
,
trainable_variables
,
pre_allreduce_callbacks
=
None
,
post_allreduce_callbacks
=
None
,
allreduce_bytes_per_pack
=
0
):
"""Minimizes loss for one step by updating `trainable_variables`.
Minimizes loss for one step by updating `trainable_variables`.
This explicitly performs gradient allreduce, instead of relying on implicit
allreduce in optimizer.apply_gradients(). If training using FP16 mixed
precision, explicit allreduce will aggregate gradients in FP16 format.
For TPU and GPU training using FP32, explicit allreduce will aggregate
gradients in FP32 format.
Args:
tape: An instance of `tf.GradientTape`.
optimizer: An instance of `tf.keras.optimizers.Optimizer`.
loss: the loss tensor.
trainable_variables: A list of model Variables.
pre_allreduce_callbacks: A list of callback functions that takes gradients
and model variables pairs as input, manipulate them, and returns a new
gradients and model variables pairs. The callback functions will be
invoked in the list order and before gradients are allreduced. With
mixed precision training, the pre_allreduce_allbacks will be applied on
scaled_gradients. Default is no callbacks.
post_allreduce_callbacks: A list of callback functions that takes
gradients and model variables pairs as input, manipulate them, and
returns a new gradients and model variables paris. The callback
functions will be invoked in the list order and right before gradients
are applied to variables for updates. Default is no callbacks.
allreduce_bytes_per_pack: A non-negative integer. Breaks collective
operations into packs of certain size. If it's zero, all gradients are
in one pack.
"""
if
isinstance
(
optimizer
,
tf
.
keras
.
mixed_precision
.
LossScaleOptimizer
):
# FP16 GPU code path
with
tape
:
scaled_loss
=
optimizer
.
get_scaled_loss
(
loss
)
scaled_grads
=
tape
.
gradient
(
scaled_loss
,
trainable_variables
)
grads_and_vars
=
zip
(
scaled_grads
,
trainable_variables
)
if
pre_allreduce_callbacks
:
grads_and_vars
=
_run_callbacks
(
pre_allreduce_callbacks
,
grads_and_vars
)
(
allreduced_scaled_grads
,
filtered_training_vars
)
=
_filter_and_allreduce_gradients
(
grads_and_vars
,
allreduce_precision
=
"float16"
,
bytes_per_pack
=
allreduce_bytes_per_pack
)
allreduced_unscaled_grads
=
optimizer
.
get_unscaled_gradients
(
allreduced_scaled_grads
)
grads_and_vars
=
zip
(
allreduced_unscaled_grads
,
filtered_training_vars
)
else
:
# TPU or FP32 GPU code path
grads
=
tape
.
gradient
(
loss
,
trainable_variables
)
grads_and_vars
=
zip
(
grads
,
trainable_variables
)
if
pre_allreduce_callbacks
:
grads_and_vars
=
_run_callbacks
(
pre_allreduce_callbacks
,
grads_and_vars
)
(
allreduced_grads
,
filtered_training_vars
)
=
_filter_and_allreduce_gradients
(
grads_and_vars
,
allreduce_precision
=
"float32"
,
bytes_per_pack
=
allreduce_bytes_per_pack
)
grads_and_vars
=
zip
(
allreduced_grads
,
filtered_training_vars
)
if
post_allreduce_callbacks
:
grads_and_vars
=
_run_callbacks
(
post_allreduce_callbacks
,
grads_and_vars
)
optimizer
.
apply_gradients
(
grads_and_vars
,
experimental_aggregate_gradients
=
False
)
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/hyperparams/__init__.py
0 → 100644
View file @
a32ffa95
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Hyperparams package definition."""
# pylint: disable=g-multiple-import
from
official.modeling.hyperparams.base_config
import
*
from
official.modeling.hyperparams.oneof
import
*
from
official.modeling.hyperparams.params_dict
import
*
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/hyperparams/base_config.py
0 → 100644
View file @
a32ffa95
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base configurations to standardize experiments."""
import
copy
import
dataclasses
import
functools
import
inspect
from
typing
import
Any
,
List
,
Mapping
,
Optional
,
Type
from
absl
import
logging
import
tensorflow
as
tf
import
yaml
from
official.modeling.hyperparams
import
params_dict
_BOUND
=
set
()
def
bind
(
config_cls
):
"""Bind a class to config cls."""
if
not
inspect
.
isclass
(
config_cls
):
raise
ValueError
(
'The bind decorator is supposed to apply on the class '
f
'attribute. Received
{
config_cls
}
, not a class.'
)
def
decorator
(
builder
):
if
config_cls
in
_BOUND
:
raise
ValueError
(
'Inside a program, we should not bind the config with a'
' class twice.'
)
if
inspect
.
isclass
(
builder
):
config_cls
.
_BUILDER
=
builder
# pylint: disable=protected-access
elif
inspect
.
isfunction
(
builder
):
def
_wrapper
(
self
,
*
args
,
**
kwargs
):
# pylint: disable=unused-argument
return
builder
(
*
args
,
**
kwargs
)
config_cls
.
_BUILDER
=
_wrapper
# pylint: disable=protected-access
else
:
raise
ValueError
(
f
'The `BUILDER` type is not supported:
{
builder
}
'
)
_BOUND
.
add
(
config_cls
)
return
builder
return
decorator
@
dataclasses
.
dataclass
class
Config
(
params_dict
.
ParamsDict
):
"""The base configuration class that supports YAML/JSON based overrides.
Because of YAML/JSON serialization limitations, some semantics of dataclass
are not supported:
* It recursively enforces a allowlist of basic types and container types, so
it avoids surprises with copy and reuse caused by unanticipated types.
* Warning: it converts Dict to `Config` even within sequences,
e.g. for config = Config({'key': [([{'a': 42}],)]),
type(config.key[0][0][0]) is Config rather than dict.
If you define/annotate some field as Dict, the field will convert to a
`Config` instance and lose the dictionary type.
"""
# The class or method to bind with the params class.
_BUILDER
=
None
# It's safe to add bytes and other immutable types here.
IMMUTABLE_TYPES
=
(
str
,
int
,
float
,
bool
,
type
(
None
))
# It's safe to add set, frozenset and other collections here.
SEQUENCE_TYPES
=
(
list
,
tuple
)
default_params
:
dataclasses
.
InitVar
[
Optional
[
Mapping
[
str
,
Any
]]]
=
None
restrictions
:
dataclasses
.
InitVar
[
Optional
[
List
[
str
]]]
=
None
def
__post_init__
(
self
,
default_params
,
restrictions
):
super
().
__init__
(
default_params
=
default_params
,
restrictions
=
restrictions
)
@
property
def
BUILDER
(
self
):
return
self
.
_BUILDER
@
classmethod
def
_isvalidsequence
(
cls
,
v
):
"""Check if the input values are valid sequences.
Args:
v: Input sequence.
Returns:
True if the sequence is valid. Valid sequence includes the sequence
type in cls.SEQUENCE_TYPES and element type is in cls.IMMUTABLE_TYPES or
is dict or ParamsDict.
"""
if
not
isinstance
(
v
,
cls
.
SEQUENCE_TYPES
):
return
False
return
(
all
(
isinstance
(
e
,
cls
.
IMMUTABLE_TYPES
)
for
e
in
v
)
or
all
(
isinstance
(
e
,
dict
)
for
e
in
v
)
or
all
(
isinstance
(
e
,
params_dict
.
ParamsDict
)
for
e
in
v
))
@
classmethod
def
_import_config
(
cls
,
v
,
subconfig_type
):
"""Returns v with dicts converted to Configs, recursively."""
if
not
issubclass
(
subconfig_type
,
params_dict
.
ParamsDict
):
raise
TypeError
(
'Subconfig_type should be subclass of ParamsDict, found {!r}'
.
format
(
subconfig_type
))
if
isinstance
(
v
,
cls
.
IMMUTABLE_TYPES
):
return
v
elif
isinstance
(
v
,
cls
.
SEQUENCE_TYPES
):
# Only support one layer of sequence.
if
not
cls
.
_isvalidsequence
(
v
):
raise
TypeError
(
'Invalid sequence: only supports single level {!r} of {!r} or '
'dict or ParamsDict found: {!r}'
.
format
(
cls
.
SEQUENCE_TYPES
,
cls
.
IMMUTABLE_TYPES
,
v
))
import_fn
=
functools
.
partial
(
cls
.
_import_config
,
subconfig_type
=
subconfig_type
)
return
type
(
v
)(
map
(
import_fn
,
v
))
elif
isinstance
(
v
,
params_dict
.
ParamsDict
):
# Deepcopy here is a temporary solution for preserving type in nested
# Config object.
return
copy
.
deepcopy
(
v
)
elif
isinstance
(
v
,
dict
):
return
subconfig_type
(
v
)
else
:
raise
TypeError
(
'Unknown type: {!r}'
.
format
(
type
(
v
)))
@
classmethod
def
_export_config
(
cls
,
v
):
"""Returns v with Configs converted to dicts, recursively."""
if
isinstance
(
v
,
cls
.
IMMUTABLE_TYPES
):
return
v
elif
isinstance
(
v
,
cls
.
SEQUENCE_TYPES
):
return
type
(
v
)(
map
(
cls
.
_export_config
,
v
))
elif
isinstance
(
v
,
params_dict
.
ParamsDict
):
return
v
.
as_dict
()
elif
isinstance
(
v
,
dict
):
raise
TypeError
(
'dict value not supported in converting.'
)
else
:
raise
TypeError
(
'Unknown type: {!r}'
.
format
(
type
(
v
)))
@
classmethod
def
_get_subconfig_type
(
cls
,
k
)
->
Type
[
params_dict
.
ParamsDict
]:
"""Get element type by the field name.
Args:
k: the key/name of the field.
Returns:
Config as default. If a type annotation is found for `k`,
1) returns the type of the annotation if it is subtype of ParamsDict;
2) returns the element type if the annotation of `k` is List[SubType]
or Tuple[SubType].
"""
subconfig_type
=
Config
if
k
in
cls
.
__annotations__
:
# Directly Config subtype.
type_annotation
=
cls
.
__annotations__
[
k
]
# pytype: disable=invalid-annotation
if
(
isinstance
(
type_annotation
,
type
)
and
issubclass
(
type_annotation
,
Config
)):
subconfig_type
=
cls
.
__annotations__
[
k
]
# pytype: disable=invalid-annotation
else
:
# Check if the field is a sequence of subtypes.
field_type
=
getattr
(
type_annotation
,
'__origin__'
,
type
(
None
))
if
(
isinstance
(
field_type
,
type
)
and
issubclass
(
field_type
,
cls
.
SEQUENCE_TYPES
)):
element_type
=
getattr
(
type_annotation
,
'__args__'
,
[
type
(
None
)])[
0
]
subconfig_type
=
(
element_type
if
issubclass
(
element_type
,
params_dict
.
ParamsDict
)
else
subconfig_type
)
return
subconfig_type
def
_set
(
self
,
k
,
v
):
"""Overrides same method in ParamsDict.
Also called by ParamsDict methods.
Args:
k: key to set.
v: value.
Raises:
RuntimeError
"""
subconfig_type
=
self
.
_get_subconfig_type
(
k
)
def
is_null
(
k
):
if
k
not
in
self
.
__dict__
or
not
self
.
__dict__
[
k
]:
return
True
return
False
if
isinstance
(
v
,
dict
):
if
is_null
(
k
):
# If the key not exist or the value is None, a new Config-family object
# sould be created for the key.
self
.
__dict__
[
k
]
=
subconfig_type
(
v
)
else
:
self
.
__dict__
[
k
].
override
(
v
)
elif
not
is_null
(
k
)
and
isinstance
(
v
,
self
.
SEQUENCE_TYPES
)
and
all
(
[
not
isinstance
(
e
,
self
.
IMMUTABLE_TYPES
)
for
e
in
v
]):
if
len
(
self
.
__dict__
[
k
])
==
len
(
v
):
for
i
in
range
(
len
(
v
)):
self
.
__dict__
[
k
][
i
].
override
(
v
[
i
])
elif
not
all
([
isinstance
(
e
,
self
.
IMMUTABLE_TYPES
)
for
e
in
v
]):
logging
.
warning
(
"The list/tuple don't match the value dictionaries provided. Thus, "
'the list/tuple is determined by the type annotation and '
'values provided. This is error-prone.'
)
self
.
__dict__
[
k
]
=
self
.
_import_config
(
v
,
subconfig_type
)
else
:
self
.
__dict__
[
k
]
=
self
.
_import_config
(
v
,
subconfig_type
)
else
:
self
.
__dict__
[
k
]
=
self
.
_import_config
(
v
,
subconfig_type
)
def
__setattr__
(
self
,
k
,
v
):
if
k
==
'BUILDER'
or
k
==
'_BUILDER'
:
raise
AttributeError
(
'`BUILDER` is a property and `_BUILDER` is the '
'reserved class attribute. We should only assign '
'`_BUILDER` at the class level.'
)
if
k
not
in
self
.
RESERVED_ATTR
:
if
getattr
(
self
,
'_locked'
,
False
):
raise
ValueError
(
'The Config has been locked. '
'No change is allowed.'
)
self
.
_set
(
k
,
v
)
def
_override
(
self
,
override_dict
,
is_strict
=
True
):
"""Overrides same method in ParamsDict.
Also called by ParamsDict methods.
Args:
override_dict: dictionary to write to .
is_strict: If True, not allows to add new keys.
Raises:
KeyError: overriding reserved keys or keys not exist (is_strict=True).
"""
for
k
,
v
in
sorted
(
override_dict
.
items
()):
if
k
in
self
.
RESERVED_ATTR
:
raise
KeyError
(
'The key {!r} is internally reserved. '
'Can not be overridden.'
.
format
(
k
))
if
k
not
in
self
.
__dict__
:
if
is_strict
:
raise
KeyError
(
'The key {!r} does not exist in {!r}. '
'To extend the existing keys, use '
'`override` with `is_strict` = False.'
.
format
(
k
,
type
(
self
)))
else
:
self
.
_set
(
k
,
v
)
else
:
if
isinstance
(
v
,
dict
)
and
self
.
__dict__
[
k
]:
self
.
__dict__
[
k
].
_override
(
v
,
is_strict
)
# pylint: disable=protected-access
elif
isinstance
(
v
,
params_dict
.
ParamsDict
)
and
self
.
__dict__
[
k
]:
self
.
__dict__
[
k
].
_override
(
v
.
as_dict
(),
is_strict
)
# pylint: disable=protected-access
else
:
self
.
_set
(
k
,
v
)
def
as_dict
(
self
):
"""Returns a dict representation of params_dict.ParamsDict.
For the nested params_dict.ParamsDict, a nested dict will be returned.
"""
return
{
k
:
self
.
_export_config
(
v
)
for
k
,
v
in
self
.
__dict__
.
items
()
if
k
not
in
self
.
RESERVED_ATTR
}
def
replace
(
self
,
**
kwargs
):
"""Overrides/returns a unlocked copy with the current config unchanged."""
# pylint: disable=protected-access
params
=
copy
.
deepcopy
(
self
)
params
.
_locked
=
False
params
.
_override
(
kwargs
,
is_strict
=
True
)
# pylint: enable=protected-access
return
params
@
classmethod
def
from_yaml
(
cls
,
file_path
:
str
):
# Note: This only works if the Config has all default values.
with
tf
.
io
.
gfile
.
GFile
(
file_path
,
'r'
)
as
f
:
loaded
=
yaml
.
load
(
f
,
Loader
=
yaml
.
FullLoader
)
config
=
cls
()
config
.
override
(
loaded
)
return
config
@
classmethod
def
from_json
(
cls
,
file_path
:
str
):
"""Wrapper for `from_yaml`."""
return
cls
.
from_yaml
(
file_path
)
@
classmethod
def
from_args
(
cls
,
*
args
,
**
kwargs
):
"""Builds a config from the given list of arguments."""
attributes
=
list
(
cls
.
__annotations__
.
keys
())
default_params
=
{
a
:
p
for
a
,
p
in
zip
(
attributes
,
args
)}
default_params
.
update
(
kwargs
)
return
cls
(
default_params
=
default_params
)
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/hyperparams/base_config_test.py
0 → 100644
View file @
a32ffa95
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
pprint
from
typing
import
List
,
Tuple
from
absl.testing
import
parameterized
import
dataclasses
import
tensorflow
as
tf
from
official.modeling.hyperparams
import
base_config
@
dataclasses
.
dataclass
class
DumpConfig1
(
base_config
.
Config
):
a
:
int
=
1
b
:
str
=
'text'
@
dataclasses
.
dataclass
class
DumpConfig2
(
base_config
.
Config
):
c
:
int
=
2
d
:
str
=
'text'
e
:
DumpConfig1
=
DumpConfig1
()
@
dataclasses
.
dataclass
class
DumpConfig3
(
DumpConfig2
):
f
:
int
=
2
g
:
str
=
'text'
h
:
List
[
DumpConfig1
]
=
dataclasses
.
field
(
default_factory
=
lambda
:
[
DumpConfig1
(),
DumpConfig1
()])
g
:
Tuple
[
DumpConfig1
,
...]
=
(
DumpConfig1
(),)
@
dataclasses
.
dataclass
class
DumpConfig4
(
DumpConfig2
):
x
:
int
=
3
@
dataclasses
.
dataclass
class
DummyConfig5
(
base_config
.
Config
):
y
:
Tuple
[
DumpConfig2
,
...]
=
(
DumpConfig2
(),
DumpConfig4
())
z
:
Tuple
[
str
]
=
(
'a'
,)
class
BaseConfigTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
def
assertHasSameTypes
(
self
,
c
,
d
,
msg
=
''
):
"""Checks if a Config has the same structure as a given dict.
Args:
c: the Config object to be check.
d: the reference dict object.
msg: The error message to show when type mismatched.
"""
# Make sure d is not a Config. Assume d is either
# dictionary or primitive type and c is the Config or primitive types.
self
.
assertNotIsInstance
(
d
,
base_config
.
Config
)
if
isinstance
(
d
,
base_config
.
Config
.
IMMUTABLE_TYPES
):
self
.
assertEqual
(
pprint
.
pformat
(
c
),
pprint
.
pformat
(
d
),
msg
=
msg
)
elif
isinstance
(
d
,
base_config
.
Config
.
SEQUENCE_TYPES
):
self
.
assertEqual
(
type
(
c
),
type
(
d
),
msg
=
msg
)
for
i
,
v
in
enumerate
(
d
):
self
.
assertHasSameTypes
(
c
[
i
],
v
,
msg
=
'{}[{!r}]'
.
format
(
msg
,
i
))
elif
isinstance
(
d
,
dict
):
self
.
assertIsInstance
(
c
,
base_config
.
Config
,
msg
=
msg
)
for
k
,
v
in
sorted
(
d
.
items
()):
self
.
assertHasSameTypes
(
getattr
(
c
,
k
),
v
,
msg
=
'{}[{!r}]'
.
format
(
msg
,
k
))
else
:
raise
TypeError
(
'Unknown type: %r'
%
type
(
d
))
def
assertImportExport
(
self
,
v
):
config
=
base_config
.
Config
({
'key'
:
v
})
back
=
config
.
as_dict
()[
'key'
]
self
.
assertEqual
(
pprint
.
pformat
(
back
),
pprint
.
pformat
(
v
))
self
.
assertHasSameTypes
(
config
.
key
,
v
,
msg
=
'=%s v'
%
pprint
.
pformat
(
v
))
def
test_invalid_keys
(
self
):
params
=
base_config
.
Config
()
with
self
.
assertRaises
(
AttributeError
):
_
=
params
.
a
def
test_cls
(
self
):
params
=
base_config
.
Config
()
with
self
.
assertRaisesRegex
(
AttributeError
,
'`BUILDER` is a property and `_BUILDER` is the reserved'
):
params
.
BUILDER
=
DumpConfig2
with
self
.
assertRaisesRegex
(
AttributeError
,
'`BUILDER` is a property and `_BUILDER` is the reserved'
):
params
.
_BUILDER
=
DumpConfig2
base_config
.
bind
(
DumpConfig1
)(
DumpConfig2
)
params
=
DumpConfig1
()
self
.
assertEqual
(
params
.
BUILDER
,
DumpConfig2
)
with
self
.
assertRaisesRegex
(
ValueError
,
'Inside a program, we should not bind'
):
base_config
.
bind
(
DumpConfig1
)(
DumpConfig2
)
def
_test
():
return
'test'
base_config
.
bind
(
DumpConfig2
)(
_test
)
params
=
DumpConfig2
()
self
.
assertEqual
(
params
.
BUILDER
(),
'test'
)
def
test_nested_config_types
(
self
):
config
=
DumpConfig3
()
self
.
assertIsInstance
(
config
.
e
,
DumpConfig1
)
self
.
assertIsInstance
(
config
.
h
[
0
],
DumpConfig1
)
self
.
assertIsInstance
(
config
.
h
[
1
],
DumpConfig1
)
self
.
assertIsInstance
(
config
.
g
[
0
],
DumpConfig1
)
config
.
override
({
'e'
:
{
'a'
:
2
,
'b'
:
'new text'
}})
self
.
assertIsInstance
(
config
.
e
,
DumpConfig1
)
self
.
assertEqual
(
config
.
e
.
a
,
2
)
self
.
assertEqual
(
config
.
e
.
b
,
'new text'
)
config
.
override
({
'h'
:
[{
'a'
:
3
,
'b'
:
'new text 2'
}]})
self
.
assertIsInstance
(
config
.
h
[
0
],
DumpConfig1
)
self
.
assertLen
(
config
.
h
,
1
)
self
.
assertEqual
(
config
.
h
[
0
].
a
,
3
)
self
.
assertEqual
(
config
.
h
[
0
].
b
,
'new text 2'
)
config
.
override
({
'g'
:
[{
'a'
:
4
,
'b'
:
'new text 3'
}]})
self
.
assertIsInstance
(
config
.
g
[
0
],
DumpConfig1
)
self
.
assertLen
(
config
.
g
,
1
)
self
.
assertEqual
(
config
.
g
[
0
].
a
,
4
)
self
.
assertEqual
(
config
.
g
[
0
].
b
,
'new text 3'
)
def
test_replace
(
self
):
config
=
DumpConfig2
()
new_config
=
config
.
replace
(
e
=
{
'a'
:
2
})
self
.
assertEqual
(
new_config
.
e
.
a
,
2
)
self
.
assertIsInstance
(
new_config
.
e
,
DumpConfig1
)
config
=
DumpConfig2
(
e
=
DumpConfig2
())
new_config
=
config
.
replace
(
e
=
{
'c'
:
4
})
self
.
assertEqual
(
new_config
.
e
.
c
,
4
)
self
.
assertIsInstance
(
new_config
.
e
,
DumpConfig2
)
config
=
DumpConfig3
()
new_config
=
config
.
replace
(
g
=
[{
'a'
:
4
,
'b'
:
'new text 3'
}])
self
.
assertIsInstance
(
new_config
.
g
[
0
],
DumpConfig1
)
self
.
assertEqual
(
new_config
.
g
[
0
].
a
,
4
)
@
parameterized
.
parameters
(
(
'_locked'
,
"The key '_locked' is internally reserved."
),
(
'_restrictions'
,
"The key '_restrictions' is internally reserved."
),
(
'aa'
,
"The key 'aa' does not exist."
),
)
def
test_key_error
(
self
,
key
,
msg
):
params
=
base_config
.
Config
()
with
self
.
assertRaisesRegex
(
KeyError
,
msg
):
params
.
override
({
key
:
True
})
@
parameterized
.
parameters
(
(
'str data'
,),
(
123
,),
(
1.23
,),
(
None
,),
([
'str'
,
1
,
2.3
,
None
],),
((
'str'
,
1
,
2.3
,
None
),),
)
def
test_import_export_immutable_types
(
self
,
v
):
self
.
assertImportExport
(
v
)
out
=
base_config
.
Config
({
'key'
:
v
})
self
.
assertEqual
(
pprint
.
pformat
(
v
),
pprint
.
pformat
(
out
.
key
))
def
test_override_is_strict_true
(
self
):
params
=
base_config
.
Config
({
'a'
:
'aa'
,
'b'
:
2
,
'c'
:
{
'c1'
:
'cc'
,
'c2'
:
20
}
})
params
.
override
({
'a'
:
2
,
'c'
:
{
'c1'
:
'ccc'
}},
is_strict
=
True
)
self
.
assertEqual
(
params
.
a
,
2
)
self
.
assertEqual
(
params
.
c
.
c1
,
'ccc'
)
with
self
.
assertRaises
(
KeyError
):
params
.
override
({
'd'
:
'ddd'
},
is_strict
=
True
)
with
self
.
assertRaises
(
KeyError
):
params
.
override
({
'c'
:
{
'c3'
:
30
}},
is_strict
=
True
)
config
=
base_config
.
Config
({
'key'
:
[{
'a'
:
42
}]})
with
self
.
assertRaisesRegex
(
KeyError
,
"The key 'b' does not exist"
):
config
.
override
({
'key'
:
[{
'b'
:
43
}]})
@
parameterized
.
parameters
(
(
lambda
x
:
x
,
'Unknown type'
),
(
object
(),
'Unknown type'
),
(
set
(),
'Unknown type'
),
(
frozenset
(),
'Unknown type'
),
)
def
test_import_unsupport_types
(
self
,
v
,
msg
):
with
self
.
assertRaisesRegex
(
TypeError
,
msg
):
_
=
base_config
.
Config
({
'key'
:
v
})
@
parameterized
.
parameters
(
({
'a'
:
[{
'b'
:
2
,
},
{
'c'
:
3
,
}]
},),
({
'c'
:
[{
'f'
:
1.1
,
},
{
'h'
:
[
1
,
2
],
}]
},),
(({
'a'
:
'aa'
,
'b'
:
2
,
'c'
:
{
'c1'
:
10
,
'c2'
:
20
,
}
},),),
)
def
test_import_export_nested_structure
(
self
,
d
):
self
.
assertImportExport
(
d
)
@
parameterized
.
parameters
(
([{
'a'
:
42
,
'b'
:
'hello'
,
'c'
:
1.2
}],),
(({
'a'
:
42
,
'b'
:
'hello'
,
'c'
:
1.2
},),),
)
def
test_import_export_nested_sequences
(
self
,
v
):
self
.
assertImportExport
(
v
)
@
parameterized
.
parameters
(
([([{}],)],),
([[
'str'
,
1
,
2.3
,
None
]],),
(((
'str'
,
1
,
2.3
,
None
),),),
([
(
'str'
,
1
,
2.3
,
None
),
],),
([
(
'str'
,
1
,
2.3
,
None
),
],),
([[{
'a'
:
42
,
'b'
:
'hello'
,
'c'
:
1.2
}]],),
([[[{
'a'
:
42
,
'b'
:
'hello'
,
'c'
:
1.2
}]]],),
((({
'a'
:
42
,
'b'
:
'hello'
,
'c'
:
1.2
},),),),
(((({
'a'
:
42
,
'b'
:
'hello'
,
'c'
:
1.2
},),),),),
([({
'a'
:
42
,
'b'
:
'hello'
,
'c'
:
1.2
},)],),
(([{
'a'
:
42
,
'b'
:
'hello'
,
'c'
:
1.2
}],),),
)
def
test_import_export_unsupport_sequence
(
self
,
v
):
with
self
.
assertRaisesRegex
(
TypeError
,
'Invalid sequence: only supports single level'
):
_
=
base_config
.
Config
({
'key'
:
v
})
def
test_construct_subtype
(
self
):
pass
def
test_import_config
(
self
):
params
=
base_config
.
Config
({
'a'
:
[{
'b'
:
2
},
{
'c'
:
{
'd'
:
3
}}]})
self
.
assertLen
(
params
.
a
,
2
)
self
.
assertEqual
(
params
.
a
[
0
].
b
,
2
)
self
.
assertEqual
(
type
(
params
.
a
[
0
]),
base_config
.
Config
)
self
.
assertEqual
(
pprint
.
pformat
(
params
.
a
[
0
].
b
),
'2'
)
self
.
assertEqual
(
type
(
params
.
a
[
1
]),
base_config
.
Config
)
self
.
assertEqual
(
type
(
params
.
a
[
1
].
c
),
base_config
.
Config
)
self
.
assertEqual
(
pprint
.
pformat
(
params
.
a
[
1
].
c
.
d
),
'3'
)
def
test_override
(
self
):
params
=
base_config
.
Config
({
'a'
:
[{
'b'
:
2
},
{
'c'
:
{
'd'
:
3
}}]})
params
.
override
({
'a'
:
[{
'b'
:
4
},
{
'c'
:
{
'd'
:
5
}}]},
is_strict
=
False
)
self
.
assertEqual
(
type
(
params
.
a
),
list
)
self
.
assertEqual
(
type
(
params
.
a
[
0
]),
base_config
.
Config
)
self
.
assertEqual
(
pprint
.
pformat
(
params
.
a
[
0
].
b
),
'4'
)
self
.
assertEqual
(
type
(
params
.
a
[
1
]),
base_config
.
Config
)
self
.
assertEqual
(
type
(
params
.
a
[
1
].
c
),
base_config
.
Config
)
self
.
assertEqual
(
pprint
.
pformat
(
params
.
a
[
1
].
c
.
d
),
'5'
)
@
parameterized
.
parameters
(
([{}],),
(({},),),
)
def
test_config_vs_params_dict
(
self
,
v
):
d
=
{
'key'
:
v
}
self
.
assertEqual
(
type
(
base_config
.
Config
(
d
).
key
[
0
]),
base_config
.
Config
)
self
.
assertEqual
(
type
(
base_config
.
params_dict
.
ParamsDict
(
d
).
key
[
0
]),
dict
)
def
test_ppformat
(
self
):
self
.
assertEqual
(
pprint
.
pformat
([
's'
,
1
,
1.0
,
True
,
None
,
{},
[],
(),
{
(
2
,):
(
3
,
[
4
],
{
6
:
7
,
}),
8
:
9
,
}
]),
"['s', 1, 1.0, True, None, {}, [], (), {8: 9, (2,): (3, [4], {6: 7})}]"
)
def
test_with_restrictions
(
self
):
restrictions
=
[
'e.a<c'
]
config
=
DumpConfig2
(
restrictions
=
restrictions
)
config
.
validate
()
def
test_nested_tuple
(
self
):
config
=
DummyConfig5
()
config
.
override
({
'y'
:
[{
'c'
:
4
,
'd'
:
'new text 3'
,
'e'
:
{
'a'
:
2
}
},
{
'c'
:
0
,
'd'
:
'new text 3'
,
'e'
:
{
'a'
:
2
}
}],
'z'
:
[
'a'
,
'b'
,
'c'
],
})
self
.
assertEqual
(
config
.
y
[
0
].
c
,
4
)
self
.
assertEqual
(
config
.
y
[
1
].
c
,
0
)
self
.
assertIsInstance
(
config
.
y
[
0
],
DumpConfig2
)
self
.
assertIsInstance
(
config
.
y
[
1
],
DumpConfig4
)
self
.
assertSameElements
(
config
.
z
,
[
'a'
,
'b'
,
'c'
])
def
test_override_by_empty_sequence
(
self
):
config
=
DummyConfig5
()
config
.
override
({
'y'
:
[],
'z'
:
(),
},
is_strict
=
True
)
self
.
assertEmpty
(
config
.
y
)
self
.
assertEmpty
(
config
.
z
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
TensorFlow2x/ComputeVision/Classification/models-master/official/modeling/hyperparams/oneof.py
0 → 100644
View file @
a32ffa95
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Config class that supports oneof functionality."""
from
typing
import
Optional
import
dataclasses
from
official.modeling.hyperparams
import
base_config
@
dataclasses
.
dataclass
class
OneOfConfig
(
base_config
.
Config
):
"""Configuration for configs with one of feature.
Attributes:
type: 'str', name of the field to select.
"""
type
:
Optional
[
str
]
=
None
def
as_dict
(
self
):
"""Returns a dict representation of OneOfConfig.
For the nested base_config.Config, a nested dict will be returned.
"""
if
self
.
type
is
None
:
return
{
'type'
:
None
}
elif
self
.
__dict__
[
'type'
]
not
in
self
.
__dict__
:
raise
ValueError
(
'type: {!r} is not a valid key!'
.
format
(
self
.
__dict__
[
'type'
]))
else
:
chosen_type
=
self
.
type
chosen_value
=
self
.
__dict__
[
chosen_type
]
return
{
'type'
:
self
.
type
,
chosen_type
:
self
.
_export_config
(
chosen_value
)}
def
get
(
self
):
"""Returns selected config based on the value of type.
If type is not set (None), None is returned.
"""
chosen_type
=
self
.
type
if
chosen_type
is
None
:
return
None
if
chosen_type
not
in
self
.
__dict__
:
raise
ValueError
(
'type: {!r} is not a valid key!'
.
format
(
self
.
type
))
return
self
.
__dict__
[
chosen_type
]
Prev
1
…
6
7
8
9
10
11
12
13
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment