ModelZoo / ResNet50_tensorflow

Commit 88253ce5, authored Aug 12, 2020 by Hongkun Yu, committed by A. Unique TensorFlower on Aug 12, 2020.

    Internal change

    PiperOrigin-RevId: 326286926

Parent: 52371ffe
Changes: 205 files in total; this page shows 20 changed files with 324 additions and 259 deletions (+324, -259).
Changed files on this page:

  official/core/base_task.py                                           +12   -9
  official/core/base_trainer.py                                         +7   -5
  official/core/base_trainer_test.py                                   +13  -10
  official/modeling/hyperparams/base_config.py                          +5   -4
  official/modeling/hyperparams/config_definitions.py                  +10  -11
  official/modeling/hyperparams/oneof.py                                +4   -8
  official/modeling/hyperparams/oneof_test.py                          +12   -6
  official/modeling/hyperparams/params_dict.py                         +33  -23
  official/modeling/hyperparams/params_dict_test.py                   +113  -41
  official/modeling/optimization/configs/learning_rate_config.py       +18  -25
  official/modeling/optimization/configs/optimization_config_test.py    +3   -4
  official/modeling/optimization/configs/optimizer_config.py            +8   -9
  official/modeling/optimization/optimizer_factory.py                   +3   -3
  official/modeling/optimization/optimizer_factory_test.py             +58  -62
  official/modeling/performance.py                                      +3   -4
  official/modeling/tf_utils.py                                         +2   -4
  official/modeling/training/distributed_executor.py                   +17  -27
  official/nlp/albert/configs.py                                        +1   -4
  official/nlp/albert/export_albert_tfhub.py                            +1   -0
  official/nlp/albert/run_classifier.py                                 +1   -0
official/core/base_task.py  (view file @ 88253ce5)

@@ -171,26 +171,30 @@ class Task(tf.Module):
     return []

   def process_metrics(self, metrics, labels, model_outputs):
-    """Process and update metrics. Called when using custom training loop API.
+    """Process and update metrics.
+
+    Called when using custom training loop API.

     Args:
       metrics: a nested structure of metrics objects. The return of function
        self.build_metrics.
       labels: a tensor or a nested structure of tensors.
       model_outputs: a tensor or a nested structure of tensors. For example,
        output of the keras model built by self.build_model.
     """
     for metric in metrics:
       metric.update_state(labels, model_outputs)

   def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
-    """Process and update compiled_metrics. call when using compile/fit API.
+    """Process and update compiled_metrics.
+
+    call when using compile/fit API.

     Args:
       compiled_metrics: the compiled metrics (model.compiled_metrics).
       labels: a tensor or a nested structure of tensors.
       model_outputs: a tensor or a nested structure of tensors. For example,
        output of the keras model built by self.build_model.
     """
     compiled_metrics.update_state(labels, model_outputs)

@@ -297,4 +301,3 @@ class Task(tf.Module):
   def reduce_aggregated_logs(self, aggregated_logs):
     """Optional reduce of aggregated logs over validation steps."""
     return {}

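Not part of the commit: a minimal standalone sketch of the metric-update loop that process_metrics documents above, using a plain Keras metric; the names and data are illustrative only.

    import tensorflow as tf

    def process_metrics(metrics, labels, model_outputs):
      # Same loop as Task.process_metrics above: update every metric in the
      # structure with the current labels and model outputs.
      for metric in metrics:
        metric.update_state(labels, model_outputs)

    accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
    labels = tf.constant([1, 2])
    model_outputs = tf.constant([[0.1, 0.8, 0.1], [0.2, 0.2, 0.6]])
    process_metrics([accuracy], labels, model_outputs)
    print(float(accuracy.result()))  # 1.0 for this toy batch
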
official/core/base_trainer.py  (view file @ 88253ce5)

@@ -19,6 +19,7 @@ The base trainer implements the Orbit `StandardTrainable` and
 `StandardEvaluable` interfaces. Trainers inside this project should be
 interchangable and independent on model architectures and tasks.
 """
+import gin
 import orbit
 import tensorflow as tf

@@ -28,7 +29,6 @@ from official.modeling import optimization
 from official.modeling import performance
 from official.modeling.hyperparams import config_definitions

 ExperimentConfig = config_definitions.ExperimentConfig

@@ -52,8 +52,8 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
       default to True.
     evaluate: bool, whether or not this trainer will be used for evaluation.
       default to True.
     model: tf.keras.Model instance. If provided, it will be used instead of
       building model using task.build_model(). Default to None.
     optimizer: tf.keras.optimizers.Optimizer instance. If provided, it will
       used instead of the optimizer from config. Default to None.
   """

@@ -90,8 +90,10 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
     else:
       checkpoint_items = {}
     self._checkpoint = tf.train.Checkpoint(
         global_step=self.global_step,
         model=self.model,
         optimizer=self.optimizer,
         **checkpoint_items)
     self._train_loss = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
     self._validation_loss = tf.keras.metrics.Mean(

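Not part of the commit: a small sketch of the checkpoint pattern reformatted above. tf.train.Checkpoint accepts arbitrary trackable objects as keyword arguments, which is what lets the Trainer merge extra items in via **checkpoint_items; every name and path below is illustrative.

    import tensorflow as tf

    model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
    global_step = tf.Variable(0, dtype=tf.int64, trainable=False)

    # Extra trackables supplied as a dict, mirroring the Trainer's checkpoint_items.
    checkpoint_items = {'best_accuracy': tf.Variable(0.0, trainable=False)}

    checkpoint = tf.train.Checkpoint(
        global_step=global_step,
        model=model,
        optimizer=optimizer,
        **checkpoint_items)
    save_path = checkpoint.save('/tmp/trainer_sketch/ckpt')  # hypothetical path
    print('checkpoint written to', save_path)
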
official/core/base_trainer_test.py  (view file @ 88253ce5)

@@ -15,6 +15,7 @@
 # ==============================================================================
 """Tests for tensorflow_models.core.trainers.trainer."""
 # pylint: disable=g-direct-tensorflow-import
 from absl.testing import parameterized
 import tensorflow as tf

@@ -42,13 +43,14 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
     super().setUp()
     self._config = cfg.ExperimentConfig(
         trainer=cfg.TrainerConfig(
             optimizer_config=cfg.OptimizationConfig({
                 'optimizer': {
                     'type': 'sgd'
                 },
                 'learning_rate': {
                     'type': 'constant'
                 }
             })))

   def create_test_trainer(self):
     task = mock_task.MockTask()

@@ -81,13 +83,14 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
         runtime=cfg.RuntimeConfig(
             mixed_precision_dtype=mixed_precision_dtype, loss_scale=loss_scale),
         trainer=cfg.TrainerConfig(
             optimizer_config=cfg.OptimizationConfig({
                 'optimizer': {
                     'type': 'sgd'
                 },
                 'learning_rate': {
                     'type': 'constant'
                 }
             })))
     task = mock_task.MockTask()
     trainer = trainer_lib.Trainer(config, task)
     if mixed_precision_dtype != 'float16':

official/modeling/hyperparams/base_config.py  (view file @ 88253ce5)

@@ -136,7 +136,8 @@ class Config(params_dict.ParamsDict):
     return subconfig_type

   def __post_init__(self, default_params, restrictions, *args, **kwargs):
     super().__init__(
         default_params=default_params,
         restrictions=restrictions,
         *args,
         **kwargs)

official/modeling/hyperparams/config_definitions.py  (view file @ 88253ce5)

@@ -55,14 +55,14 @@ class DataConfig(base_config.Config):
       exhaust all the examples in the dataset.
     tfds_data_dir: A str specifying the directory to read/write TFDS data.
     tfds_download: A bool to indicate whether to download data using TFDS.
     tfds_as_supervised: A bool. When loading dataset from TFDS, if True, the
       returned tf.data.Dataset will have a 2-tuple structure (input, label)
       according to builder.info.supervised_keys; if False, the default, the
       returned tf.data.Dataset will have a dictionary with all the features.
     tfds_skip_decoding_feature: A str to indicate which features are skipped
       for decoding when loading dataset from TFDS. Use comma to separate
       multiple features. The main use case is to skip the image/video decoding
       for better performance.
   """
   input_path: str = ""
   tfds_name: str = ""

@@ -177,8 +177,8 @@ class TrainerConfig(base_config.Config):
     checkpoint_interval: number of steps between checkpoints.
     max_to_keep: max checkpoints to keep.
     continuous_eval_timeout: maximum number of seconds to wait between
       checkpoints, if set to None, continuous eval will wait indefinitely. This
       is only used continuous_train_and_eval and continuous_eval modes.
     train_steps: number of train steps.
     validation_steps: number of eval steps. If `None`, the entire eval dataset
       is used.

@@ -217,4 +217,3 @@ class ExperimentConfig(base_config.Config):
   task: TaskConfig = TaskConfig()
   trainer: TrainerConfig = TrainerConfig()
   runtime: RuntimeConfig = RuntimeConfig()

official/modeling/hyperparams/oneof.py  (view file @ 88253ce5)

@@ -38,15 +38,12 @@ class OneOfConfig(base_config.Config):
     if self.type is None:
       return {'type': None}
     elif self.__dict__['type'] not in self.__dict__:
       raise ValueError('type: {!r} is not a valid key!'.format(
           self.__dict__['type']))
     else:
       chosen_type = self.type
       chosen_value = self.__dict__[chosen_type]
       return {'type': self.type, chosen_type: self._export_config(chosen_value)}

   def get(self):
     """Returns selected config based on the value of type.

@@ -57,6 +54,5 @@ class OneOfConfig(base_config.Config):
     if chosen_type is None:
       return None
     if chosen_type not in self.__dict__:
       raise ValueError('type: {!r} is not a valid key!'.format(self.type))
     return self.__dict__[chosen_type]

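Not part of the commit: a stripped-down mimic of the export logic kept above, to make the 'type'-keyed contract concrete. This is illustrative code, not the repository's OneOfConfig.

    def export_oneof(config):
      """Mimics OneOfConfig.as_dict(): keep only the entry selected by 'type'."""
      chosen_type = config.get('type')
      if chosen_type is None:
        return {'type': None}
      if chosen_type not in config:
        raise ValueError('type: {!r} is not a valid key!'.format(chosen_type))
      return {'type': chosen_type, chosen_type: config[chosen_type]}

    print(export_oneof({'type': 'resnet', 'resnet': {'model_depth': 50}}))
    print(export_oneof({'type': None}))
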
official/modeling/hyperparams/oneof_test.py  (view file @ 88253ce5)

@@ -48,11 +48,17 @@ class Network(base_config.Config):
 class OneOfTest(tf.test.TestCase):

   def test_to_dict(self):
     network_params = {
         'backbone': {
             'type': 'resnet',
             'resnet': {
                 'model_depth': 50
             }
         },
         'output_layer': {
             'type': 'single',
             'single': 1000
         }
     }
     network_config = Network(network_params)
     self.assertEqual(network_config.as_dict(), network_params)

official/modeling/hyperparams/params_dict.py  (view file @ 88253ce5)

@@ -30,7 +30,8 @@ import yaml
 # key-value pair string. It splits each k-v pair on the = sign, and
 # matches on values that are within single quotes, double quotes, single
 # values (e.g. floats, ints, etc.), and a lists within brackets.
 _PARAM_RE = re.compile(
     r"""
   (?P<name>[a-zA-Z][\w\.]*)    # variable name: "var" or "x"
   \s*=\s*
   ((?P<val>\'(.*?)\'           # single quote

@@ -138,8 +139,8 @@ class ParamsDict(object):
       ValueError: if the ParamsDict instance has been locked.
     """
     if k in ParamsDict.RESERVED_ATTR:
       raise AttributeError('The key `{}` is reserved. No change is allowes. '
                            .format(k))
     if k not in self.__dict__.keys():
       raise AttributeError('The key `{}` does not exist. '.format(k))
     if self._locked:

@@ -150,13 +151,13 @@ class ParamsDict(object):
     """Override the ParamsDict with a set of given params.

     Args:
       override_params: a dict or a ParamsDict specifying the parameters to be
         overridden.
       is_strict: a boolean specifying whether override is strict or not. If
         True, keys in `override_params` must be present in the ParamsDict. If
         False, keys in `override_params` can be different from what is
         currently defined in the ParamsDict. In this case, the ParamsDict will
         be extended to include the new keys.
     """
     if self._locked:
       raise ValueError('The ParamsDict has been locked. No change is allowed.')

@@ -240,6 +241,7 @@ class ParamsDict(object):
       (2) any inconsistency violating the restriction is found.
       ValueError: if the restriction defined in the string is not supported.
     """

     def _get_kv(dotted_string, params_dict):
       """Get keys and values indicated by dotted_string."""
       if _CONST_VALUE_RE.match(dotted_string) is not None:

@@ -270,38 +272,44 @@ class ParamsDict(object):
       tokens = restriction.split('==')
       _, left_v, _, right_v = _get_kvs(tokens, params_dict)
       if left_v != right_v:
         raise KeyError('Found inconsistncy between key `{}` and key `{}`.'
                        .format(tokens[0], tokens[1]))
     elif '!=' in restriction:
       tokens = restriction.split('!=')
       _, left_v, _, right_v = _get_kvs(tokens, params_dict)
       if left_v == right_v:
         raise KeyError('Found inconsistncy between key `{}` and key `{}`.'
                        .format(tokens[0], tokens[1]))
     elif '<' in restriction:
       tokens = restriction.split('<')
       _, left_v, _, right_v = _get_kvs(tokens, params_dict)
       if left_v >= right_v:
         raise KeyError('Found inconsistncy between key `{}` and key `{}`.'
                        .format(tokens[0], tokens[1]))
     elif '<=' in restriction:
       tokens = restriction.split('<=')
       _, left_v, _, right_v = _get_kvs(tokens, params_dict)
       if left_v > right_v:
         raise KeyError('Found inconsistncy between key `{}` and key `{}`.'
                        .format(tokens[0], tokens[1]))
     elif '>' in restriction:
       tokens = restriction.split('>')
       _, left_v, _, right_v = _get_kvs(tokens, params_dict)
       if left_v <= right_v:
         raise KeyError('Found inconsistncy between key `{}` and key `{}`.'
                        .format(tokens[0], tokens[1]))
     elif '>=' in restriction:
       tokens = restriction.split('>=')
       _, left_v, _, right_v = _get_kvs(tokens, params_dict)
       if left_v < right_v:
         raise KeyError('Found inconsistncy between key `{}` and key `{}`.'
                        .format(tokens[0], tokens[1]))
     else:
       raise ValueError('Unsupported relation in restriction.')

@@ -316,10 +324,12 @@ def read_yaml_to_params_dict(file_path):
 def save_params_dict_to_yaml(params, file_path):
   """Saves the input ParamsDict to a YAML file."""
   with tf.io.gfile.GFile(file_path, 'w') as f:

     def _my_list_rep(dumper, data):
       # u'tag:yaml.org,2002:seq' is the YAML internal tag for sequence.
       return dumper.represent_sequence(
           u'tag:yaml.org,2002:seq', data, flow_style=True)

     yaml.add_representer(list, _my_list_rep)
     yaml.dump(params.as_dict(), f, default_flow_style=False)

@@ -408,8 +418,8 @@ def override_params_dict(params, dict_or_string_or_yaml_file, is_strict):
   Args:
     params: a ParamsDict object to be overridden.
     dict_or_string_or_yaml_file: a Python dict, JSON/YAML/CSV string or path to
       a YAML file specifying the parameters to be overridden.
     is_strict: a boolean specifying whether override is strict or not.

   Returns:

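Not part of the commit: a brief usage sketch of the ParamsDict behaviors touched above (override, restrictions, validate), following the calls exercised in params_dict_test.py below; the keys are made up for illustration.

    from official.modeling.hyperparams import params_dict

    # A restriction ties two dotted keys together; validate() re-checks it
    # after overrides, using the comparison logic shown above.
    params = params_dict.ParamsDict(
        {'lr': 0.1, 'warmup': {'lr': 0.1}}, ['lr == warmup.lr'])

    # Non-strict override may introduce brand-new keys.
    params.override({'batch_size': 64}, is_strict=False)

    # Overriding an existing key so that the restriction is violated:
    # validate() then raises KeyError, as the tests below expect.
    params.override({'lr': 0.2})
    try:
      params.validate()
    except KeyError as e:
      print(e)
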
official/modeling/hyperparams/params_dict_test.py  (view file @ 88253ce5)

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 """Tests for params_dict.py."""

 import os

@@ -56,8 +55,7 @@ class ParamsDictTest(tf.test.TestCase):
   def test_setattr(self):
     params = params_dict.ParamsDict()
     params.override({'a': 'aa', 'b': 2, 'c': None}, is_strict=False)
     params.c = 'ccc'
     self.assertEqual(params.a, 'aa')
     self.assertEqual(params.b, 2)

@@ -65,16 +63,22 @@ class ParamsDictTest(tf.test.TestCase):
   def test_getattr(self):
     params = params_dict.ParamsDict()
     params.override({'a': 'aa', 'b': 2, 'c': None}, is_strict=False)
     self.assertEqual(params.a, 'aa')
     self.assertEqual(params.b, 2)
     self.assertEqual(params.c, None)

   def test_delattr(self):
     params = params_dict.ParamsDict()
     params.override({
         'a': 'aa',
         'b': 2,
         'c': None,
         'd': {
             'd1': 1,
             'd2': 10
         }
     }, is_strict=False)
     del params.c
     self.assertEqual(params.a, 'aa')

@@ -87,22 +91,26 @@ class ParamsDictTest(tf.test.TestCase):
   def test_contains(self):
     params = params_dict.ParamsDict()
     params.override({'a': 'aa'}, is_strict=False)
     self.assertIn('a', params)
     self.assertNotIn('b', params)

   def test_get(self):
     params = params_dict.ParamsDict()
     params.override({'a': 'aa'}, is_strict=False)
     self.assertEqual(params.get('a'), 'aa')
     self.assertEqual(params.get('b', 2), 2)
     self.assertEqual(params.get('b'), None)

   def test_override_is_strict_true(self):
     params = params_dict.ParamsDict({
         'a': 'aa',
         'b': 2,
         'c': {
             'c1': 'cc',
             'c2': 20
         }
     })
     params.override({'a': 2, 'c': {'c1': 'ccc'}}, is_strict=True)
     self.assertEqual(params.a, 2)
     self.assertEqual(params.c.c1, 'ccc')

@@ -112,8 +120,14 @@ class ParamsDictTest(tf.test.TestCase):
       params.override({'c': {'c3': 30}}, is_strict=True)

   def test_override_is_strict_false(self):
     params = params_dict.ParamsDict({
         'a': 'aa',
         'b': 2,
         'c': {
             'c1': 10,
             'c2': 20
         }
     })
     params.override({'a': 2, 'c': {'c3': 3000}}, is_strict=False)
     self.assertEqual(params.a, 2)
     self.assertEqual(params.c.c3, 3000)

@@ -123,8 +137,14 @@ class ParamsDictTest(tf.test.TestCase):
     self.assertEqual(params.c.c4, 4444)

   def test_as_dict(self):
     params = params_dict.ParamsDict({
         'a': 'aa',
         'b': 2,
         'c': {
             'c1': 10,
             'c2': 20
         }
     })
     params_d = params.as_dict()
     self.assertEqual(params_d['a'], 'aa')
     self.assertEqual(params_d['b'], 2)

@@ -134,21 +154,25 @@ class ParamsDictTest(tf.test.TestCase):
   def test_validate(self):
     # Raise error due to the unknown parameter.
     with self.assertRaises(KeyError):
       params = params_dict.ParamsDict({'a': 1, 'b': {'a': 11}}, ['a == c'])

     # OK to check equality of two nested dicts.
     params = params_dict.ParamsDict({
         'a': 1,
         'b': {
             'a': 10
         },
         'c': {
             'a': 10
         }
     }, ['b == c'])

     # Raise error due to inconsistency
     with self.assertRaises(KeyError):
       params = params_dict.ParamsDict({'a': 1, 'c': {'a': 10}}, ['a == c.a'])

     # Valid rule.
     params = params_dict.ParamsDict({'a': 1, 'c': {'a': 1}}, ['a == c.a'])

     # Overridding violates the existing rule, raise error upon validate.
     params.override({'a': 11})

@@ -156,12 +180,20 @@ class ParamsDictTest(tf.test.TestCase):
       params.validate()

     # Valid restrictions with constant.
     params = params_dict.ParamsDict({
         'a': None,
         'c': {
             'a': 1
         }
     }, ['a == None', 'c.a == 1'])
     params.validate()
     with self.assertRaises(KeyError):
       params = params_dict.ParamsDict({
           'a': 4,
           'c': {
               'a': 1
           }
       }, ['a == None', 'c.a == 1'])


 class ParamsDictIOTest(tf.test.TestCase):

@@ -173,8 +205,14 @@ class ParamsDictIOTest(tf.test.TestCase):
     return temp_file

   def test_save_params_dict_to_yaml(self):
     params = params_dict.ParamsDict({
         'a': 'aa',
         'b': 2,
         'c': {
             'c1': 10,
             'c2': 20
         }
     })
     output_yaml_file = os.path.join(self.get_temp_dir(), 'params.yaml')
     params_dict.save_params_dict_to_yaml(params, output_yaml_file)

@@ -203,7 +241,12 @@ class ParamsDictIOTest(tf.test.TestCase):
   def test_override_params_dict_using_dict(self):
     params = params_dict.ParamsDict({
         'a': 1,
         'b': 2.5,
         'c': [3, 4],
         'd': 'hello',
         'e': False
     })
     override_dict = {'b': 5.2, 'c': [30, 40]}
     params = params_dict.override_params_dict(
         params, override_dict, is_strict=True)

@@ -215,7 +258,12 @@ class ParamsDictIOTest(tf.test.TestCase):
   def test_override_params_dict_using_yaml_string(self):
     params = params_dict.ParamsDict({
         'a': 1,
         'b': 2.5,
         'c': [3, 4],
         'd': 'hello',
         'e': False
     })
     override_yaml_string = "'b': 5.2\n'c': [30, 40]"
     params = params_dict.override_params_dict(
         params, override_yaml_string, is_strict=True)

@@ -227,8 +275,18 @@ class ParamsDictIOTest(tf.test.TestCase):
   def test_override_params_dict_using_json_string(self):
     params = params_dict.ParamsDict({
         'a': 1,
         'b': {
             'b1': 2,
             'b2': [2, 3],
         },
         'd': {
             'd1': {
                 'd2': 'hello'
             }
         },
         'e': False
     })
     override_json_string = "{ b: { b2: [3, 4] }, d: { d1: { d2: 'hi' } } }"
     params = params_dict.override_params_dict(
         params, override_json_string, is_strict=True)

@@ -240,8 +298,18 @@ class ParamsDictIOTest(tf.test.TestCase):
   def test_override_params_dict_using_csv_string(self):
     params = params_dict.ParamsDict({
         'a': 1,
         'b': {
             'b1': 2,
             'b2': [2, 3],
         },
         'd': {
             'd1': {
                 'd2': 'hello'
             }
         },
         'e': False
     })
     override_csv_string = "b.b2=[3,4], d.d1.d2='hi, world', e=gs://test"
     params = params_dict.override_params_dict(
         params, override_csv_string, is_strict=True)

@@ -253,7 +321,12 @@ class ParamsDictIOTest(tf.test.TestCase):
   def test_override_params_dict_using_yaml_file(self):
     params = params_dict.ParamsDict({
         'a': 1,
         'b': 2.5,
         'c': [3, 4],
         'd': 'hello',
         'e': False
     })
     override_yaml_file = self.write_temp_file(
         'params.yaml', r"""
       b: 5.2

@@ -321,8 +394,7 @@ class IOTest(tf.test.TestCase):
   def test_csv_str_load_unsupported_datatypes(self):
     csv_str = 'a=[[1,2,3],[4,5,6]]'
     self.assertRaises(ValueError, params_dict.nested_csv_str_to_json_str,
                       csv_str)

   def test_csv_str_to_json_str_spacing(self):

official/modeling/optimization/configs/learning_rate_config.py  (view file @ 88253ce5)

@@ -50,16 +50,13 @@ class StepwiseLrConfig(base_config.Config):
   Attributes:
     name: The name of the learning rate schedule. Defaults to PiecewiseConstant.
-    boundaries: A list of ints of strictly increasing entries.
-      Defaults to None.
+    boundaries: A list of ints of strictly increasing entries. Defaults to None.
     values: A list of floats that specifies the values for the intervals defined
       by `boundaries`. It should have one more element than `boundaries`.
-      The learning rate is computed as follows:
-        [0, boundaries[0]] -> values[0]
-        [boundaries[0], boundaries[1]] -> values[1]
-        [boundaries[n-1], boundaries[n]] -> values[n]
-        [boundaries[n], end] -> values[n+1]
-      Defaults to None.
+      The learning rate is computed as follows: [0, boundaries[0]] ->
+        values[0] [boundaries[0], boundaries[1]] -> values[1]
+        [boundaries[n-1], boundaries[n]] -> values[n] [boundaries[n],
+        end] -> values[n+1] Defaults to None.
   """
   name: str = 'PiecewiseConstantDecay'
   boundaries: Optional[List[int]] = None

@@ -74,10 +71,9 @@ class ExponentialLrConfig(base_config.Config):
   Attributes:
     name: The name of the learning rate schedule. Defaults to ExponentialDecay.
-    initial_learning_rate: A float. The initial learning rate. Defaults to
-      None.
-    decay_steps: A positive integer that is used for decay computation.
-      Defaults to None.
+    initial_learning_rate: A float. The initial learning rate. Defaults to None.
+    decay_steps: A positive integer that is used for decay computation. Defaults
+      to None.
     decay_rate: A float. Defaults to None.
     staircase: A boolean, if true, learning rate is decreased at discreate
       intervals. Defaults to False.

@@ -97,10 +93,9 @@ class PolynomialLrConfig(base_config.Config):
   Attributes:
     name: The name of the learning rate schedule. Defaults to PolynomialDecay.
-    initial_learning_rate: A float. The initial learning rate. Defaults to
-      None.
-    decay_steps: A positive integer that is used for decay computation.
-      Defaults to None.
+    initial_learning_rate: A float. The initial learning rate. Defaults to None.
+    decay_steps: A positive integer that is used for decay computation. Defaults
+      to None.
     end_learning_rate: A float. The minimal end learning rate.
     power: A float. The power of the polynomial. Defaults to linear, 1.0.
     cycle: A boolean, whether or not it should cycle beyond decay_steps.

@@ -123,10 +118,9 @@ class CosineLrConfig(base_config.Config):
   Attributes:
     name: The name of the learning rate schedule. Defaults to CosineDecay.
-    initial_learning_rate: A float. The initial learning rate. Defaults to
-      None.
-    decay_steps: A positive integer that is used for decay computation.
-      Defaults to None.
+    initial_learning_rate: A float. The initial learning rate. Defaults to None.
+    decay_steps: A positive integer that is used for decay computation. Defaults
+      to None.
     alpha: A float. Minimum learning rate value as a fraction of
       initial_learning_rate.
   """

@@ -173,4 +167,3 @@ class PolynomialWarmupConfig(base_config.Config):
   name: str = 'polynomial'
   power: float = 1
   warmup_steps: Optional[int] = None

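Not part of the commit: the boundaries/values contract documented above, shown directly on the Keras schedule that StepwiseLrConfig names (PiecewiseConstantDecay); the numbers match the stepwise cases in optimizer_factory_test.py below.

    import tensorflow as tf

    # step in [0, 10000]      -> 0.1
    # step in (10000, 20000]  -> 0.01
    # step in (20000, ...)    -> 0.001
    lr = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
        boundaries=[10000, 20000], values=[0.1, 0.01, 0.001])

    for step in [0, 5000, 10000, 10001, 20000, 20001]:
      print(step, float(lr(step)))
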
official/modeling/optimization/configs/optimization_config_test.py  (view file @ 88253ce5)

@@ -50,12 +50,11 @@ class OptimizerConfigTest(tf.test.TestCase):
             'type': 'linear'
         }
     })
     self.assertEqual(opt_config.optimizer.get(), opt_cfg.SGDConfig())
     self.assertEqual(opt_config.learning_rate.get(), lr_cfg.PolynomialLrConfig())
     self.assertEqual(opt_config.warmup.get(), lr_cfg.LinearWarmupConfig())


 if __name__ == '__main__':
   tf.test.main()

official/modeling/optimization/configs/optimizer_config.py  (view file @ 88253ce5)

@@ -123,12 +123,11 @@ class LAMBConfig(base_config.Config):
     epsilon: epsilon value used for numerical stability in LAMB optimizer.
     weight_decay_rate: float. Weight decay rate. Default to 0.
     exclude_from_weight_decay: List of regex patterns of variables excluded from
       weight decay. Variables whose name contain a substring matching the
       pattern will be excluded.
     exclude_from_layer_adaptation: List of regex patterns of variables excluded
       from layer adaptation. Variables whose name contain a substring matching
       the pattern will be excluded.
   """
   name: str = "LAMB"
   beta_1: float = 0.9

official/modeling/optimization/optimizer_factory.py  (view file @ 88253ce5)

@@ -131,8 +131,9 @@ class OptimizerFactory(object):
     rate built using self.build_lr() is passed as an argument to this method.

     Args:
-      lr: A floating point value, or
-        a tf.keras.optimizers.schedules.LearningRateSchedule instance.
+      lr: A floating point value, or a
+        tf.keras.optimizers.schedules.LearningRateSchedule instance.

     Returns:
       tf.keras.optimizers.Optimizer instance.
     """

@@ -142,4 +143,3 @@ class OptimizerFactory(object):
     optimizer = OPTIMIZERS_CLS[self._optimizer_type](**optimizer_dict)
     return optimizer

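Not part of the commit: a condensed usage sketch of the factory documented above, mirroring optimizer_factory_test.py below; only build_learning_rate() is exercised, and the config values are copied from that test.

    from official.modeling.optimization import optimizer_factory
    from official.modeling.optimization.configs import optimization_config

    params = {
        'optimizer': {'type': 'sgd', 'sgd': {'momentum': 0.9}},
        'learning_rate': {
            'type': 'stepwise',
            'stepwise': {'boundaries': [10000, 20000],
                         'values': [0.1, 0.01, 0.001]}
        },
        'warmup': {
            'type': 'linear',
            'linear': {'warmup_steps': 500, 'warmup_learning_rate': 0.01}
        }
    }

    opt_config = optimization_config.OptimizationConfig(params)
    opt_factory = optimizer_factory.OptimizerFactory(opt_config)

    # The built schedule can be evaluated at arbitrary steps, as the tests below do.
    lr = opt_factory.build_learning_rate()
    print(float(lr(250)))   # ~0.055, halfway through the linear warmup
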
official/modeling/optimization/optimizer_factory_test.py  (view file @ 88253ce5)

@@ -25,12 +25,7 @@ from official.modeling.optimization.configs import optimization_config
 class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):

   @parameterized.parameters(('sgd'), ('rmsprop'), ('adam'), ('adamw'), ('lamb'))
   def test_optimizers(self, optimizer_type):
     params = {
         'optimizer': {

@@ -56,20 +51,17 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
     self.assertEqual(expected_optimizer_config, optimizer.get_config())

   def test_missing_types(self):
     params = {'optimizer': {'type': 'sgd', 'sgd': {'momentum': 0.9}}}
     with self.assertRaises(ValueError):
       optimizer_factory.OptimizerFactory(
           optimization_config.OptimizationConfig(params))
     params = {
         'learning_rate': {
             'type': 'stepwise',
             'stepwise': {
                 'boundaries': [10000, 20000],
                 'values': [0.1, 0.01, 0.001]
             }
         }
     }
     with self.assertRaises(ValueError):

@@ -80,22 +72,20 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
     params = {
         'optimizer': {
             'type': 'sgd',
             'sgd': {
                 'momentum': 0.9
             }
         },
         'learning_rate': {
             'type': 'stepwise',
             'stepwise': {
                 'boundaries': [10000, 20000],
                 'values': [0.1, 0.01, 0.001]
             }
         }
     }
     expected_lr_step_values = [[0, 0.1], [5000, 0.1], [10000, 0.1],
                                [10001, 0.01], [20000, 0.01], [20001, 0.001]]
     opt_config = optimization_config.OptimizationConfig(params)
     opt_factory = optimizer_factory.OptimizerFactory(opt_config)
     lr = opt_factory.build_learning_rate()

@@ -107,28 +97,28 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
     params = {
         'optimizer': {
             'type': 'sgd',
             'sgd': {
                 'momentum': 0.9
             }
         },
         'learning_rate': {
             'type': 'stepwise',
             'stepwise': {
                 'boundaries': [10000, 20000],
                 'values': [0.1, 0.01, 0.001]
             }
         },
         'warmup': {
             'type': 'linear',
             'linear': {
                 'warmup_steps': 500,
                 'warmup_learning_rate': 0.01
             }
         }
     }
     expected_lr_step_values = [[0, 0.01], [250, 0.055], [500, 0.1], [5500, 0.1],
                                [10000, 0.1], [10001, 0.01], [20000, 0.01],
                                [20001, 0.001]]
     opt_config = optimization_config.OptimizationConfig(params)
     opt_factory = optimizer_factory.OptimizerFactory(opt_config)
     lr = opt_factory.build_learning_rate()

@@ -140,7 +130,9 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
     params = {
         'optimizer': {
             'type': 'sgd',
             'sgd': {
                 'momentum': 0.9
             }
         },
         'learning_rate': {
             'type': 'exponential',

@@ -170,7 +162,9 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
     params = {
         'optimizer': {
             'type': 'sgd',
             'sgd': {
                 'momentum': 0.9
             }
         },
         'learning_rate': {
             'type': 'polynomial',

@@ -194,7 +188,9 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
     params = {
         'optimizer': {
             'type': 'sgd',
             'sgd': {
                 'momentum': 0.9
             }
         },
         'learning_rate': {
             'type': 'cosine',

@@ -204,11 +200,8 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
             }
         }
     }
     expected_lr_step_values = [[0, 0.1], [250, 0.08535534], [500, 0.04999999],
                                [750, 0.01464466], [1000, 0]]
     opt_config = optimization_config.OptimizationConfig(params)
     opt_factory = optimizer_factory.OptimizerFactory(opt_config)
     lr = opt_factory.build_learning_rate()

@@ -220,7 +213,9 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
     params = {
         'optimizer': {
             'type': 'sgd',
             'sgd': {
                 'momentum': 0.9
             }
         },
         'learning_rate': {
             'type': 'constant',

@@ -250,28 +245,28 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
     params = {
         'optimizer': {
             'type': 'sgd',
             'sgd': {
                 'momentum': 0.9
             }
         },
         'learning_rate': {
             'type': 'stepwise',
             'stepwise': {
                 'boundaries': [10000, 20000],
                 'values': [0.1, 0.01, 0.001]
             }
         },
         'warmup': {
             'type': 'polynomial',
             'polynomial': {
                 'warmup_steps': 500,
                 'power': 2.
             }
         }
     }
     expected_lr_step_values = [[0, 0.0], [250, 0.025], [500, 0.1], [5500, 0.1],
                                [10000, 0.1], [10001, 0.01], [20000, 0.01],
                                [20001, 0.001]]
     opt_config = optimization_config.OptimizationConfig(params)
     opt_factory = optimizer_factory.OptimizerFactory(opt_config)
     lr = opt_factory.build_learning_rate()

@@ -279,5 +274,6 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
     for step, value in expected_lr_step_values:
       self.assertAlmostEqual(lr(step).numpy(), value)


 if __name__ == '__main__':
   tf.test.main()

official/modeling/performance.py  (view file @ 88253ce5)

@@ -21,7 +21,7 @@ import tensorflow as tf
 def configure_optimizer(optimizer,
                         use_float16=False,
                         use_graph_rewrite=False,
-                        loss_scale="dynamic"):
+                        loss_scale='dynamic'):
   """Configures optimizer object with performance options."""
   if use_float16:
     # Wraps optimizer with a LossScaleOptimizer. This is done automatically

@@ -47,10 +47,9 @@ def set_mixed_precision_policy(dtype, loss_scale=None):
         'mixed_float16', loss_scale=loss_scale)
     tf.keras.mixed_precision.experimental.set_policy(policy)
   elif dtype == tf.bfloat16:
     policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
     tf.keras.mixed_precision.experimental.set_policy(policy)
   elif dtype == tf.float32:
     tf.keras.mixed_precision.experimental.set_policy('float32')
   else:
-    raise ValueError("Unexpected dtype: %s" % dtype)
+    raise ValueError('Unexpected dtype: %s' % dtype)

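Not part of the commit: a sketch of what the use_float16 branch in configure_optimizer describes, using the same TF 2.3-era experimental mixed-precision API this file targets (later TF releases moved it out of .experimental); treat the exact wrapper call as an assumption.

    import tensorflow as tf

    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)

    # Dynamic loss scaling for float16 training, as the comment above explains.
    optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
        optimizer, loss_scale='dynamic')
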
official/modeling/tf_utils.py  (view file @ 88253ce5)

@@ -29,8 +29,7 @@ from official.modeling import activations
     None,
     "tf.keras.layers.Layer supports multiple positional args and kwargs as "
     "input tensors. pack/unpack inputs to override __call__ is no longer "
     "needed.")
 def pack_inputs(inputs):
   """Pack a list of `inputs` tensors to a tuple.

@@ -55,8 +54,7 @@ def pack_inputs(inputs):
     None,
     "tf.keras.layers.Layer supports multiple positional args and kwargs as "
     "input tensors. pack/unpack inputs to override __call__ is no longer "
     "needed.")
 def unpack_inputs(inputs):
   """unpack a tuple of `inputs` tensors to a tuple.

official/modeling/training/distributed_executor.py  (view file @ 88253ce5)

@@ -133,15 +133,9 @@ class SummaryWriter(object):
 class DistributedExecutor(object):
-  """Interface to train and eval models with tf.distribute.Strategy.
-  """
+  """Interface to train and eval models with tf.distribute.Strategy."""

-  def __init__(self,
-               strategy,
-               params,
-               model_fn,
-               loss_fn,
-               is_multi_host=False):
+  def __init__(self, strategy, params, model_fn, loss_fn, is_multi_host=False):
     """Constructor.

     Args:

@@ -293,8 +287,7 @@ class DistributedExecutor(object):
         raise ValueError('steps should be an Tensor. Python object may cause '
                          'retracing.')

       per_replica_losses = strategy.run(
           replicated_step, args=(next(iterator),))
       for _ in tf.range(num_steps - 1):
         per_replica_losses = strategy.run(
             replicated_step, args=(next(iterator),))

@@ -368,6 +361,7 @@ class DistributedExecutor(object):
         available checkpoints. If `False`, will do the evaluation once after the
         final step.
       save_config: bool. Whether to save params to model_dir.

     Returns:
       The training loss and eval metrics.
     """

@@ -477,16 +471,15 @@ class DistributedExecutor(object):
       # Step-0 operations
       if current_step == 0 and not latest_checkpoint_file:
         _save_checkpoint(checkpoint, model_dir,
                          checkpoint_name.format(step=current_step))
         if test_step:
           eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
           eval_metric_result = self._run_evaluation(test_step, current_step,
                                                     eval_metric, eval_iterator)
           logging.info('Step: %s evalation metric = %s.', current_step,
                        eval_metric_result)
           test_summary_writer(
               metrics=eval_metric_result, step=optimizer.iterations)
           reset_states(eval_metric)

     logging.info('Training started')

@@ -519,8 +512,7 @@ class DistributedExecutor(object):
       else:
         train_metric_result.update({'learning_rate': optimizer.lr.numpy()})
       logging.info('Train Step: %d/%d / loss = %s / training metric = %s',
                    current_step, total_steps, train_loss, train_metric_result)
       train_summary_writer(
           metrics=train_metric_result, step=optimizer.iterations)

@@ -561,8 +553,7 @@ class DistributedExecutor(object):
       eval_metric_result = self._run_evaluation(test_step, current_step,
                                                 eval_metric, eval_iterator)
       logging.info('Final evaluation metric = %s.', eval_metric_result)
       test_summary_writer(
           metrics=eval_metric_result, step=optimizer.iterations)
     self.train_summary_writer.close()
     self.eval_summary_writer.close()

@@ -696,8 +687,7 @@ class DistributedExecutor(object):
     reader = tf.compat.v1.train.NewCheckpointReader(checkpoint_path)
     current_step = reader.get_tensor('optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE')
     logging.info('Checkpoint file %s found and restoring from '
                  'checkpoint', checkpoint_path)
     status = checkpoint.restore(checkpoint_path)
     status.expect_partial().assert_existing_objects_matched()

@@ -755,8 +745,8 @@ class ExecutorBuilder(object):
   """

   def __init__(self, strategy_type=None, strategy_config=None):
     _ = distribution_utils.configure_cluster(strategy_config.worker_hosts,
                                              strategy_config.task_index)
     """Constructor.

     Args:

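Not part of the commit: a self-contained sketch of the strategy.run(...) call pattern that the train-step hunk above reformats; the step function and data are placeholders, and the default (single-device) strategy stands in for a real distribution strategy.

    import tensorflow as tf

    strategy = tf.distribute.get_strategy()  # default no-op strategy for the sketch

    def replicated_step(inputs):
      # Placeholder per-replica computation.
      return tf.reduce_sum(inputs)

    dataset = tf.data.Dataset.from_tensor_slices([[1.0, 2.0], [3.0, 4.0]]).batch(1)
    iterator = iter(strategy.experimental_distribute_dataset(dataset))

    per_replica_losses = strategy.run(replicated_step, args=(next(iterator),))
    print(strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None))
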
official/nlp/albert/configs.py  (view file @ 88253ce5)

@@ -26,10 +26,7 @@ from official.nlp.bert import configs
 class AlbertConfig(configs.BertConfig):
   """Configuration for `ALBERT`."""

-  def __init__(self,
-               num_hidden_groups=1,
-               inner_group_num=1,
-               **kwargs):
+  def __init__(self, num_hidden_groups=1, inner_group_num=1, **kwargs):
     """Constructs AlbertConfig.

     Args:

official/nlp/albert/export_albert_tfhub.py  (view file @ 88253ce5)

@@ -18,6 +18,7 @@ from __future__ import division
 # from __future__ import google_type_annotations
 from __future__ import print_function

+# Import libraries
 from absl import app
 from absl import flags
 import tensorflow as tf

official/nlp/albert/run_classifier.py  (view file @ 88253ce5)

@@ -21,6 +21,7 @@ from __future__ import print_function
 import json
 import os

+# Import libraries
 from absl import app
 from absl import flags
 from absl import logging