ModelZoo / ResNet50_tensorflow — Commit 999fae62

Authored Aug 12, 2020 by Hongkun Yu; committed by A. Unique TensorFlower on Aug 12, 2020.

Internal change

PiperOrigin-RevId: 326286926
Parent: 94561082

Changes: 205 files total; showing 20 changed files with 324 additions and 259 deletions (+324 -259).
official/core/base_task.py                                           +12 -9
official/core/base_trainer.py                                        +7 -5
official/core/base_trainer_test.py                                   +13 -10
official/modeling/hyperparams/base_config.py                         +5 -4
official/modeling/hyperparams/config_definitions.py                  +10 -11
official/modeling/hyperparams/oneof.py                               +4 -8
official/modeling/hyperparams/oneof_test.py                          +12 -6
official/modeling/hyperparams/params_dict.py                         +33 -23
official/modeling/hyperparams/params_dict_test.py                    +113 -41
official/modeling/optimization/configs/learning_rate_config.py       +18 -25
official/modeling/optimization/configs/optimization_config_test.py   +3 -4
official/modeling/optimization/configs/optimizer_config.py           +8 -9
official/modeling/optimization/optimizer_factory.py                  +3 -3
official/modeling/optimization/optimizer_factory_test.py             +58 -62
official/modeling/performance.py                                     +3 -4
official/modeling/tf_utils.py                                        +2 -4
official/modeling/training/distributed_executor.py                   +17 -27
official/nlp/albert/configs.py                                       +1 -4
official/nlp/albert/export_albert_tfhub.py                           +1 -0
official/nlp/albert/run_classifier.py                                +1 -0
official/core/base_task.py

@@ -171,26 +171,30 @@ class Task(tf.Module):
    return []

  def process_metrics(self, metrics, labels, model_outputs):
    """Process and update metrics.

    Called when using custom training loop API.

    Args:
      metrics: a nested structure of metrics objects. The return of function
        self.build_metrics.
      labels: a tensor or a nested structure of tensors.
      model_outputs: a tensor or a nested structure of tensors. For example,
        output of the keras model built by self.build_model.
    """
    for metric in metrics:
      metric.update_state(labels, model_outputs)

  def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
    """Process and update compiled_metrics.

    call when using compile/fit API.

    Args:
      compiled_metrics: the compiled metrics (model.compiled_metrics).
      labels: a tensor or a nested structure of tensors.
      model_outputs: a tensor or a nested structure of tensors. For example,
        output of the keras model built by self.build_model.
    """
    compiled_metrics.update_state(labels, model_outputs)

@@ -297,4 +301,3 @@ class Task(tf.Module):
  def reduce_aggregated_logs(self, aggregated_logs):
    """Optional reduce of aggregated logs over validation steps."""
    return {}
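For context, a minimal sketch of how a custom training loop might use these hooks. This is not part of the commit; the MyTask subclass, task_config, dataset, and optimizer are hypothetical stand-ins, while build_model/build_metrics/build_losses/process_metrics follow the Task API shown above.

    import tensorflow as tf

    # Hypothetical Task subclass and config; names are illustrative only.
    task = MyTask(task_config)
    model = task.build_model()
    metrics = task.build_metrics(training=True)
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

    def train_step(inputs):
      features, labels = inputs
      with tf.GradientTape() as tape:
        outputs = model(features, training=True)
        loss = task.build_losses(labels=labels, model_outputs=outputs)
      grads = tape.gradient(loss, model.trainable_variables)
      optimizer.apply_gradients(zip(grads, model.trainable_variables))
      # Updates every metric returned by build_metrics with (labels, outputs).
      task.process_metrics(metrics, labels, outputs)
      return loss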
official/core/base_trainer.py

@@ -19,6 +19,7 @@ The base trainer implements the Orbit `StandardTrainable` and
`StandardEvaluable` interfaces. Trainers inside this project should be
interchangable and independent on model architectures and tasks.
"""
import gin
import orbit
import tensorflow as tf

@@ -28,7 +29,6 @@ from official.modeling import optimization
from official.modeling import performance
from official.modeling.hyperparams import config_definitions

ExperimentConfig = config_definitions.ExperimentConfig

@@ -52,8 +52,8 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
        default to True.
      evaluate: bool, whether or not this trainer will be used for evaluation.
        default to True.
      model: tf.keras.Model instance. If provided, it will be used instead of
        building model using task.build_model(). Default to None.
      optimizer: tf.keras.optimizers.Optimizer instance. If provided, it will
        used instead of the optimizer from config. Default to None.
    """

@@ -90,8 +90,10 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
    else:
      checkpoint_items = {}
    self._checkpoint = tf.train.Checkpoint(
        global_step=self.global_step,
        model=self.model,
        optimizer=self.optimizer,
        **checkpoint_items)

    self._train_loss = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
    self._validation_loss = tf.keras.metrics.Mean(
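A hedged sketch of constructing the trainer with an externally built model and optimizer, which the `model`/`optimizer` arguments documented above allow. The exact keyword names are taken from that docstring but are not verified here; `config` and `task` are assumed to be an ExperimentConfig and a Task instance, as in the tests below.

    # Optional: pass a pre-built model/optimizer instead of letting the
    # trainer build them from task.build_model() and the optimizer config.
    model = task.build_model()
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
    trainer = trainer_lib.Trainer(config, task, model=model, optimizer=optimizer)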
official/core/base_trainer_test.py

@@ -15,6 +15,7 @@
# ==============================================================================
"""Tests for tensorflow_models.core.trainers.trainer."""
# pylint: disable=g-direct-tensorflow-import
from absl.testing import parameterized
import tensorflow as tf

@@ -42,13 +43,14 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
    super().setUp()
    self._config = cfg.ExperimentConfig(
        trainer=cfg.TrainerConfig(
            optimizer_config=cfg.OptimizationConfig({
                'optimizer': {
                    'type': 'sgd'
                },
                'learning_rate': {
                    'type': 'constant'
                }
            })))

  def create_test_trainer(self):
    task = mock_task.MockTask()

@@ -81,13 +83,14 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
        runtime=cfg.RuntimeConfig(
            mixed_precision_dtype=mixed_precision_dtype, loss_scale=loss_scale),
        trainer=cfg.TrainerConfig(
            optimizer_config=cfg.OptimizationConfig({
                'optimizer': {
                    'type': 'sgd'
                },
                'learning_rate': {
                    'type': 'constant'
                }
            })))
    task = mock_task.MockTask()
    trainer = trainer_lib.Trainer(config, task)
    if mixed_precision_dtype != 'float16':
official/modeling/hyperparams/base_config.py

@@ -136,10 +136,11 @@ class Config(params_dict.ParamsDict):
    return subconfig_type

  def __post_init__(self, default_params, restrictions, *args, **kwargs):
    super().__init__(
        default_params=default_params,
        restrictions=restrictions,
        *args,
        **kwargs)

  def _set(self, k, v):
    """Overrides same method in ParamsDict.
official/modeling/hyperparams/config_definitions.py

@@ -55,14 +55,14 @@ class DataConfig(base_config.Config):
      exhaust all the examples in the dataset.
    tfds_data_dir: A str specifying the directory to read/write TFDS data.
    tfds_download: A bool to indicate whether to download data using TFDS.
    tfds_as_supervised: A bool. When loading dataset from TFDS, if True, the
      returned tf.data.Dataset will have a 2-tuple structure (input, label)
      according to builder.info.supervised_keys; if False, the default, the
      returned tf.data.Dataset will have a dictionary with all the features.
    tfds_skip_decoding_feature: A str to indicate which features are skipped
      for decoding when loading dataset from TFDS. Use comma to separate
      multiple features. The main use case is to skip the image/video decoding
      for better performance.
  """
  input_path: str = ""
  tfds_name: str = ""

@@ -177,8 +177,8 @@ class TrainerConfig(base_config.Config):
    checkpoint_interval: number of steps between checkpoints.
    max_to_keep: max checkpoints to keep.
    continuous_eval_timeout: maximum number of seconds to wait between
      checkpoints, if set to None, continuous eval will wait indefinitely. This
      is only used continuous_train_and_eval and continuous_eval modes.
    train_steps: number of train steps.
    validation_steps: number of eval steps. If `None`, the entire eval dataset
      is used.

@@ -217,4 +217,3 @@ class ExperimentConfig(base_config.Config):
  task: TaskConfig = TaskConfig()
  trainer: TrainerConfig = TrainerConfig()
  runtime: RuntimeConfig = RuntimeConfig()
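As a quick illustration of the TFDS-related fields documented above, a sketch of a DataConfig using only the attributes named in that docstring. The dataset name is an example value, not something fixed by this commit.

    # Reads a dataset from TFDS as (input, label) tuples and skips image
    # decoding for speed; all field names come from the DataConfig docstring.
    data_config = DataConfig(
        tfds_name='cifar10',                 # example dataset name
        tfds_download=True,
        tfds_as_supervised=True,             # 2-tuple (input, label) structure
        tfds_skip_decoding_feature='image')  # comma-separated feature names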
official/modeling/hyperparams/oneof.py

@@ -38,15 +38,12 @@ class OneOfConfig(base_config.Config):
    if self.type is None:
      return {'type': None}
    elif self.__dict__['type'] not in self.__dict__:
      raise ValueError('type: {!r} is not a valid key!'.format(
          self.__dict__['type']))
    else:
      chosen_type = self.type
      chosen_value = self.__dict__[chosen_type]
      return {'type': self.type, chosen_type: self._export_config(chosen_value)}

  def get(self):
    """Returns selected config based on the value of type.

@@ -57,6 +54,5 @@ class OneOfConfig(base_config.Config):
    if chosen_type is None:
      return None
    if chosen_type not in self.__dict__:
      raise ValueError('type: {!r} is not a valid key!'.format(self.type))
    return self.__dict__[chosen_type]
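A small sketch of the selection behaviour of as_dict() and get(). The Backbone one-of class and its resnet sub-config are stand-ins for the config classes defined in the test file shown next; only the {'type': ..., <chosen_type>: ...} shape is taken from the code above.

    # With type='resnet', as_dict() keeps only the chosen sub-config and
    # get() returns that sub-config object directly.
    backbone = Backbone({'type': 'resnet', 'resnet': {'model_depth': 50}})
    backbone.get().model_depth   # -> 50
    backbone.as_dict()           # -> {'type': 'resnet', 'resnet': {'model_depth': 50}}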
official/modeling/hyperparams/oneof_test.py

@@ -48,12 +48,18 @@ class Network(base_config.Config):
class OneOfTest(tf.test.TestCase):

  def test_to_dict(self):
    network_params = {
        'backbone': {
            'type': 'resnet',
            'resnet': {
                'model_depth': 50
            }
        },
        'output_layer': {
            'type': 'single',
            'single': 1000
        }
    }
    network_config = Network(network_params)
    self.assertEqual(network_config.as_dict(), network_params)
official/modeling/hyperparams/params_dict.py

@@ -30,7 +30,8 @@ import yaml
# key-value pair string. It splits each k-v pair on the = sign, and
# matches on values that are within single quotes, double quotes, single
# values (e.g. floats, ints, etc.), and a lists within brackets.
_PARAM_RE = re.compile(
    r"""
  (?P<name>[a-zA-Z][\w\.]*)    # variable name: "var" or "x"
  \s*=\s*
  ((?P<val>\'(.*?)\'           # single quote

@@ -138,8 +139,8 @@ class ParamsDict(object):
      ValueError: if the ParamsDict instance has been locked.
    """
    if k in ParamsDict.RESERVED_ATTR:
      raise AttributeError(
          'The key `{}` is reserved. No change is allowes. '.format(k))
    if k not in self.__dict__.keys():
      raise AttributeError('The key `{}` does not exist. '.format(k))
    if self._locked:

@@ -150,13 +151,13 @@ class ParamsDict(object):
    """Override the ParamsDict with a set of given params.

    Args:
      override_params: a dict or a ParamsDict specifying the parameters to be
        overridden.
      is_strict: a boolean specifying whether override is strict or not. If
        True, keys in `override_params` must be present in the ParamsDict. If
        False, keys in `override_params` can be different from what is
        currently defined in the ParamsDict. In this case, the ParamsDict will
        be extended to include the new keys.
    """
    if self._locked:
      raise ValueError('The ParamsDict has been locked. No change is allowed.')

@@ -240,6 +241,7 @@ class ParamsDict(object):
      (2) any inconsistency violating the restriction is found.
      ValueError: if the restriction defined in the string is not supported.
    """

    def _get_kv(dotted_string, params_dict):
      """Get keys and values indicated by dotted_string."""
      if _CONST_VALUE_RE.match(dotted_string) is not None:

@@ -270,38 +272,44 @@ class ParamsDict(object):
        tokens = restriction.split('==')
        _, left_v, _, right_v = _get_kvs(tokens, params_dict)
        if left_v != right_v:
          raise KeyError(
              'Found inconsistncy between key `{}` and key `{}`.'.format(
                  tokens[0], tokens[1]))
      elif '!=' in restriction:
        tokens = restriction.split('!=')
        _, left_v, _, right_v = _get_kvs(tokens, params_dict)
        if left_v == right_v:
          raise KeyError(
              'Found inconsistncy between key `{}` and key `{}`.'.format(
                  tokens[0], tokens[1]))
      elif '<' in restriction:
        tokens = restriction.split('<')
        _, left_v, _, right_v = _get_kvs(tokens, params_dict)
        if left_v >= right_v:
          raise KeyError(
              'Found inconsistncy between key `{}` and key `{}`.'.format(
                  tokens[0], tokens[1]))
      elif '<=' in restriction:
        tokens = restriction.split('<=')
        _, left_v, _, right_v = _get_kvs(tokens, params_dict)
        if left_v > right_v:
          raise KeyError(
              'Found inconsistncy between key `{}` and key `{}`.'.format(
                  tokens[0], tokens[1]))
      elif '>' in restriction:
        tokens = restriction.split('>')
        _, left_v, _, right_v = _get_kvs(tokens, params_dict)
        if left_v <= right_v:
          raise KeyError(
              'Found inconsistncy between key `{}` and key `{}`.'.format(
                  tokens[0], tokens[1]))
      elif '>=' in restriction:
        tokens = restriction.split('>=')
        _, left_v, _, right_v = _get_kvs(tokens, params_dict)
        if left_v < right_v:
          raise KeyError(
              'Found inconsistncy between key `{}` and key `{}`.'.format(
                  tokens[0], tokens[1]))
      else:
        raise ValueError('Unsupported relation in restriction.')

@@ -316,10 +324,12 @@ def read_yaml_to_params_dict(file_path):
def save_params_dict_to_yaml(params, file_path):
  """Saves the input ParamsDict to a YAML file."""
  with tf.io.gfile.GFile(file_path, 'w') as f:

    def _my_list_rep(dumper, data):
      # u'tag:yaml.org,2002:seq' is the YAML internal tag for sequence.
      return dumper.represent_sequence(
          u'tag:yaml.org,2002:seq', data, flow_style=True)

    yaml.add_representer(list, _my_list_rep)
    yaml.dump(params.as_dict(), f, default_flow_style=False)

@@ -408,8 +418,8 @@ def override_params_dict(params, dict_or_string_or_yaml_file, is_strict):
  Args:
    params: a ParamsDict object to be overridden.
    dict_or_string_or_yaml_file: a Python dict, JSON/YAML/CSV string or path to
      a YAML file specifying the parameters to be overridden.
    is_strict: a boolean specifying whether override is strict or not.

  Returns:
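A short sketch of the restriction mechanism handled by the validate() chain above, mirroring the patterns exercised in the tests that follow; the concrete keys and values are illustrative.

    # 'a == c.a' ties a top-level key to a nested one; validate() raises
    # KeyError once an override breaks the relation.
    params = params_dict.ParamsDict({'a': 1, 'c': {'a': 1}}, ['a == c.a'])
    params.validate()          # passes
    params.override({'a': 11})
    params.validate()          # raises KeyError: inconsistency between a and c.a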
official/modeling/hyperparams/params_dict_test.py

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for params_dict.py."""

import os

@@ -56,8 +55,7 @@ class ParamsDictTest(tf.test.TestCase):
  def test_setattr(self):
    params = params_dict.ParamsDict()
    params.override({'a': 'aa', 'b': 2, 'c': None}, is_strict=False)
    params.c = 'ccc'
    self.assertEqual(params.a, 'aa')
    self.assertEqual(params.b, 2)

@@ -65,17 +63,23 @@ class ParamsDictTest(tf.test.TestCase):
  def test_getattr(self):
    params = params_dict.ParamsDict()
    params.override({'a': 'aa', 'b': 2, 'c': None}, is_strict=False)
    self.assertEqual(params.a, 'aa')
    self.assertEqual(params.b, 2)
    self.assertEqual(params.c, None)

  def test_delattr(self):
    params = params_dict.ParamsDict()
    params.override(
        {
            'a': 'aa',
            'b': 2,
            'c': None,
            'd': {
                'd1': 1,
                'd2': 10
            }
        },
        is_strict=False)
    del params.c
    self.assertEqual(params.a, 'aa')
    self.assertEqual(params.b, 2)

@@ -87,22 +91,26 @@ class ParamsDictTest(tf.test.TestCase):
  def test_contains(self):
    params = params_dict.ParamsDict()
    params.override({'a': 'aa'}, is_strict=False)
    self.assertIn('a', params)
    self.assertNotIn('b', params)

  def test_get(self):
    params = params_dict.ParamsDict()
    params.override({'a': 'aa'}, is_strict=False)
    self.assertEqual(params.get('a'), 'aa')
    self.assertEqual(params.get('b', 2), 2)
    self.assertEqual(params.get('b'), None)

  def test_override_is_strict_true(self):
    params = params_dict.ParamsDict({
        'a': 'aa',
        'b': 2,
        'c': {
            'c1': 'cc',
            'c2': 20
        }
    })
    params.override({'a': 2, 'c': {'c1': 'ccc'}}, is_strict=True)
    self.assertEqual(params.a, 2)
    self.assertEqual(params.c.c1, 'ccc')

@@ -112,8 +120,14 @@ class ParamsDictTest(tf.test.TestCase):
      params.override({'c': {'c3': 30}}, is_strict=True)

  def test_override_is_strict_false(self):
    params = params_dict.ParamsDict({
        'a': 'aa',
        'b': 2,
        'c': {
            'c1': 10,
            'c2': 20
        }
    })
    params.override({'a': 2, 'c': {'c3': 3000}}, is_strict=False)
    self.assertEqual(params.a, 2)
    self.assertEqual(params.c.c3, 3000)

@@ -123,8 +137,14 @@ class ParamsDictTest(tf.test.TestCase):
    self.assertEqual(params.c.c4, 4444)

  def test_as_dict(self):
    params = params_dict.ParamsDict({
        'a': 'aa',
        'b': 2,
        'c': {
            'c1': 10,
            'c2': 20
        }
    })
    params_d = params.as_dict()
    self.assertEqual(params_d['a'], 'aa')
    self.assertEqual(params_d['b'], 2)

@@ -134,21 +154,25 @@ class ParamsDictTest(tf.test.TestCase):
  def test_validate(self):
    # Raise error due to the unknown parameter.
    with self.assertRaises(KeyError):
      params = params_dict.ParamsDict({'a': 1, 'b': {'a': 11}}, ['a == c'])

    # OK to check equality of two nested dicts.
    params = params_dict.ParamsDict({
        'a': 1,
        'b': {
            'a': 10
        },
        'c': {
            'a': 10
        }
    }, ['b == c'])

    # Raise error due to inconsistency
    with self.assertRaises(KeyError):
      params = params_dict.ParamsDict({'a': 1, 'c': {'a': 10}}, ['a == c.a'])

    # Valid rule.
    params = params_dict.ParamsDict({'a': 1, 'c': {'a': 1}}, ['a == c.a'])

    # Overridding violates the existing rule, raise error upon validate.
    params.override({'a': 11})

@@ -156,12 +180,20 @@ class ParamsDictTest(tf.test.TestCase):
      params.validate()

    # Valid restrictions with constant.
    params = params_dict.ParamsDict({
        'a': None,
        'c': {
            'a': 1
        }
    }, ['a == None', 'c.a == 1'])
    params.validate()
    with self.assertRaises(KeyError):
      params = params_dict.ParamsDict({
          'a': 4,
          'c': {
              'a': 1
          }
      }, ['a == None', 'c.a == 1'])


class ParamsDictIOTest(tf.test.TestCase):

@@ -173,8 +205,14 @@ class ParamsDictIOTest(tf.test.TestCase):
    return temp_file

  def test_save_params_dict_to_yaml(self):
    params = params_dict.ParamsDict({
        'a': 'aa',
        'b': 2,
        'c': {
            'c1': 10,
            'c2': 20
        }
    })
    output_yaml_file = os.path.join(self.get_temp_dir(), 'params.yaml')
    params_dict.save_params_dict_to_yaml(params, output_yaml_file)

@@ -203,7 +241,12 @@ class ParamsDictIOTest(tf.test.TestCase):
  def test_override_params_dict_using_dict(self):
    params = params_dict.ParamsDict({
        'a': 1,
        'b': 2.5,
        'c': [3, 4],
        'd': 'hello',
        'e': False
    })
    override_dict = {'b': 5.2, 'c': [30, 40]}
    params = params_dict.override_params_dict(
        params, override_dict, is_strict=True)

@@ -215,7 +258,12 @@ class ParamsDictIOTest(tf.test.TestCase):
  def test_override_params_dict_using_yaml_string(self):
    params = params_dict.ParamsDict({
        'a': 1,
        'b': 2.5,
        'c': [3, 4],
        'd': 'hello',
        'e': False
    })
    override_yaml_string = "'b': 5.2\n'c': [30, 40]"
    params = params_dict.override_params_dict(
        params, override_yaml_string, is_strict=True)

@@ -227,8 +275,18 @@ class ParamsDictIOTest(tf.test.TestCase):
  def test_override_params_dict_using_json_string(self):
    params = params_dict.ParamsDict({
        'a': 1,
        'b': {
            'b1': 2,
            'b2': [2, 3],
        },
        'd': {
            'd1': {
                'd2': 'hello'
            }
        },
        'e': False
    })
    override_json_string = "{ b: { b2: [3, 4] }, d: { d1: { d2: 'hi' } } }"
    params = params_dict.override_params_dict(
        params, override_json_string, is_strict=True)

@@ -240,8 +298,18 @@ class ParamsDictIOTest(tf.test.TestCase):
  def test_override_params_dict_using_csv_string(self):
    params = params_dict.ParamsDict({
        'a': 1,
        'b': {
            'b1': 2,
            'b2': [2, 3],
        },
        'd': {
            'd1': {
                'd2': 'hello'
            }
        },
        'e': False
    })
    override_csv_string = "b.b2=[3,4], d.d1.d2='hi, world', e=gs://test"
    params = params_dict.override_params_dict(
        params, override_csv_string, is_strict=True)

@@ -253,7 +321,12 @@ class ParamsDictIOTest(tf.test.TestCase):
  def test_override_params_dict_using_yaml_file(self):
    params = params_dict.ParamsDict({
        'a': 1,
        'b': 2.5,
        'c': [3, 4],
        'd': 'hello',
        'e': False
    })
    override_yaml_file = self.write_temp_file(
        'params.yaml', r"""
      b: 5.2

@@ -321,8 +394,7 @@ class IOTest(tf.test.TestCase):
  def test_csv_str_load_unsupported_datatypes(self):
    csv_str = 'a=[[1,2,3],[4,5,6]]'
    self.assertRaises(ValueError, params_dict.nested_csv_str_to_json_str,
                      csv_str)

  def test_csv_str_to_json_str_spacing(self):
official/modeling/optimization/configs/learning_rate_config.py

@@ -50,16 +50,13 @@ class StepwiseLrConfig(base_config.Config):
  Attributes:
    name: The name of the learning rate schedule. Defaults to PiecewiseConstant.
    boundaries: A list of ints of strictly increasing entries. Defaults to None.
    values: A list of floats that specifies the values for the intervals defined
      by `boundaries`. It should have one more element than `boundaries`.
      The learning rate is computed as follows:
        [0, boundaries[0]] -> values[0]
        [boundaries[0], boundaries[1]] -> values[1]
        [boundaries[n-1], boundaries[n]] -> values[n]
        [boundaries[n], end] -> values[n+1]
      Defaults to None.
  """
  name: str = 'PiecewiseConstantDecay'
  boundaries: Optional[List[int]] = None

@@ -74,13 +71,12 @@ class ExponentialLrConfig(base_config.Config):
  Attributes:
    name: The name of the learning rate schedule. Defaults to ExponentialDecay.
    initial_learning_rate: A float. The initial learning rate. Defaults to None.
    decay_steps: A positive integer that is used for decay computation. Defaults
      to None.
    decay_rate: A float. Defaults to None.
    staircase: A boolean, if true, learning rate is decreased at discreate
      intervals. Defaults to False.
  """
  name: str = 'ExponentialDecay'
  initial_learning_rate: Optional[float] = None

@@ -97,14 +93,13 @@ class PolynomialLrConfig(base_config.Config):
  Attributes:
    name: The name of the learning rate schedule. Defaults to PolynomialDecay.
    initial_learning_rate: A float. The initial learning rate. Defaults to None.
    decay_steps: A positive integer that is used for decay computation. Defaults
      to None.
    end_learning_rate: A float. The minimal end learning rate.
    power: A float. The power of the polynomial. Defaults to linear, 1.0.
    cycle: A boolean, whether or not it should cycle beyond decay_steps.
      Defaults to False.
  """
  name: str = 'PolynomialDecay'
  initial_learning_rate: Optional[float] = None

@@ -123,12 +118,11 @@ class CosineLrConfig(base_config.Config):
  Attributes:
    name: The name of the learning rate schedule. Defaults to CosineDecay.
    initial_learning_rate: A float. The initial learning rate. Defaults to None.
    decay_steps: A positive integer that is used for decay computation. Defaults
      to None.
    alpha: A float. Minimum learning rate value as a fraction of
      initial_learning_rate.
  """
  name: str = 'CosineDecay'
  initial_learning_rate: Optional[float] = None

@@ -173,4 +167,3 @@ class PolynomialWarmupConfig(base_config.Config):
  name: str = 'polynomial'
  power: float = 1
  warmup_steps: Optional[int] = None
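To make the boundaries/values mapping concrete, a small sketch using the same numbers as the optimizer-factory tests later in this commit; the interval semantics follow the StepwiseLrConfig docstring above.

    # boundaries=[10000, 20000] with values=[0.1, 0.01, 0.001] means:
    #   step in [0, 10000]     -> lr 0.1
    #   step in (10000, 20000] -> lr 0.01
    #   step > 20000           -> lr 0.001
    stepwise = StepwiseLrConfig(
        boundaries=[10000, 20000], values=[0.1, 0.01, 0.001])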
official/modeling/optimization/configs/optimization_config_test.py

@@ -50,12 +50,11 @@ class OptimizerConfigTest(tf.test.TestCase):
            'type': 'linear'
        }
    })
    self.assertEqual(opt_config.optimizer.get(), opt_cfg.SGDConfig())
    self.assertEqual(opt_config.learning_rate.get(),
                     lr_cfg.PolynomialLrConfig())
    self.assertEqual(opt_config.warmup.get(), lr_cfg.LinearWarmupConfig())


if __name__ == '__main__':
  tf.test.main()
View file @
999fae62
...
@@ -72,7 +72,7 @@ class AdamConfig(base_config.Config):
...
@@ -72,7 +72,7 @@ class AdamConfig(base_config.Config):
beta_2: decay rate for 2st order moments.
beta_2: decay rate for 2st order moments.
epsilon: epsilon value used for numerical stability in Adam optimizer.
epsilon: epsilon value used for numerical stability in Adam optimizer.
amsgrad: boolean. Whether to apply AMSGrad variant of this algorithm from
amsgrad: boolean. Whether to apply AMSGrad variant of this algorithm from
the paper "On the Convergence of Adam and beyond".
the paper "On the Convergence of Adam and beyond".
"""
"""
name
:
str
=
"Adam"
name
:
str
=
"Adam"
beta_1
:
float
=
0.9
beta_1
:
float
=
0.9
...
@@ -91,12 +91,12 @@ class AdamWeightDecayConfig(base_config.Config):
...
@@ -91,12 +91,12 @@ class AdamWeightDecayConfig(base_config.Config):
beta_2: decay rate for 2st order moments.
beta_2: decay rate for 2st order moments.
epsilon: epsilon value used for numerical stability in the optimizer.
epsilon: epsilon value used for numerical stability in the optimizer.
amsgrad: boolean. Whether to apply AMSGrad variant of this algorithm from
amsgrad: boolean. Whether to apply AMSGrad variant of this algorithm from
the paper "On the Convergence of Adam and beyond".
the paper "On the Convergence of Adam and beyond".
weight_decay_rate: float. Weight decay rate. Default to 0.
weight_decay_rate: float. Weight decay rate. Default to 0.
include_in_weight_decay: list[str], or None. List of weight names to include
include_in_weight_decay: list[str], or None. List of weight names to include
in weight decay.
in weight decay.
include_in_weight_decay: list[str], or None. List of weight names to not
include_in_weight_decay: list[str], or None. List of weight names to not
include in weight decay.
include in weight decay.
"""
"""
name
:
str
=
"AdamWeightDecay"
name
:
str
=
"AdamWeightDecay"
beta_1
:
float
=
0.9
beta_1
:
float
=
0.9
...
@@ -123,12 +123,11 @@ class LAMBConfig(base_config.Config):
...
@@ -123,12 +123,11 @@ class LAMBConfig(base_config.Config):
epsilon: epsilon value used for numerical stability in LAMB optimizer.
epsilon: epsilon value used for numerical stability in LAMB optimizer.
weight_decay_rate: float. Weight decay rate. Default to 0.
weight_decay_rate: float. Weight decay rate. Default to 0.
exclude_from_weight_decay: List of regex patterns of variables excluded from
exclude_from_weight_decay: List of regex patterns of variables excluded from
weight decay. Variables whose name contain a
weight decay. Variables whose name contain a
substring matching the
substring matching the
pattern will be excluded.
pattern will be excluded.
exclude_from_layer_adaptation: List of regex patterns of variables excluded
exclude_from_layer_adaptation: List of regex patterns of variables excluded
from layer adaptation. Variables whose name
from layer adaptation. Variables whose name contain a substring matching
contain a substring matching the pattern will
the pattern will be excluded.
be excluded.
"""
"""
name
:
str
=
"LAMB"
name
:
str
=
"LAMB"
beta_1
:
float
=
0.9
beta_1
:
float
=
0.9
...
...
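A hedged sketch of how the regex-based exclusion fields described for LAMB might be filled in. The field names come from the docstring above; the variable-name patterns and rates are only examples, and keyword construction is assumed to work as for any base_config.Config dataclass.

    # Bias and LayerNorm variables are commonly kept out of weight decay; any
    # variable whose name contains one of these substrings is excluded.
    lamb = LAMBConfig(
        weight_decay_rate=0.01,
        exclude_from_weight_decay=['LayerNorm', 'layer_norm', 'bias'],
        exclude_from_layer_adaptation=['LayerNorm', 'layer_norm', 'bias'])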
official/modeling/optimization/optimizer_factory.py

@@ -131,8 +131,9 @@ class OptimizerFactory(object):
    rate built using self.build_lr() is passed as an argument to this method.

    Args:
      lr: A floating point value, or a
        tf.keras.optimizers.schedules.LearningRateSchedule instance.

    Returns:
      tf.keras.optimizers.Optimizer instance.
    """

@@ -142,4 +143,3 @@ class OptimizerFactory(object):
    optimizer = OPTIMIZERS_CLS[self._optimizer_type](**optimizer_dict)
    return optimizer
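Putting the factory pieces together, a minimal sketch based on the configs used in the tests below. The build_optimizer method name is inferred from the docstring above ("learning rate built using self.build_lr() is passed as an argument to this method") and should be treated as an assumption.

    params = {
        'optimizer': {'type': 'sgd', 'sgd': {'momentum': 0.9}},
        'learning_rate': {
            'type': 'stepwise',
            'stepwise': {'boundaries': [10000, 20000],
                         'values': [0.1, 0.01, 0.001]}
        },
    }
    opt_config = optimization_config.OptimizationConfig(params)
    opt_factory = optimizer_factory.OptimizerFactory(opt_config)
    lr = opt_factory.build_learning_rate()
    optimizer = opt_factory.build_optimizer(lr)  # assumed method name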
official/modeling/optimization/optimizer_factory_test.py

@@ -25,12 +25,7 @@ from official.modeling.optimization.configs import optimization_config
class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.parameters(('sgd'), ('rmsprop'), ('adam'), ('adamw'), ('lamb'))
  def test_optimizers(self, optimizer_type):
    params = {
        'optimizer': {

@@ -56,20 +51,17 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
    self.assertEqual(expected_optimizer_config, optimizer.get_config())

  def test_missing_types(self):
    params = {'optimizer': {'type': 'sgd', 'sgd': {'momentum': 0.9}}}
    with self.assertRaises(ValueError):
      optimizer_factory.OptimizerFactory(
          optimization_config.OptimizationConfig(params))
    params = {
        'learning_rate': {
            'type': 'stepwise',
            'stepwise': {
                'boundaries': [10000, 20000],
                'values': [0.1, 0.01, 0.001]
            }
        }
    }
    with self.assertRaises(ValueError):

@@ -80,22 +72,20 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
    params = {
        'optimizer': {
            'type': 'sgd',
            'sgd': {'momentum': 0.9}
        },
        'learning_rate': {
            'type': 'stepwise',
            'stepwise': {
                'boundaries': [10000, 20000],
                'values': [0.1, 0.01, 0.001]
            }
        }
    }
    expected_lr_step_values = [[0, 0.1], [5000, 0.1], [10000, 0.1],
                               [10001, 0.01], [20000, 0.01], [20001, 0.001]]
    opt_config = optimization_config.OptimizationConfig(params)
    opt_factory = optimizer_factory.OptimizerFactory(opt_config)
    lr = opt_factory.build_learning_rate()

@@ -107,28 +97,28 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
    params = {
        'optimizer': {
            'type': 'sgd',
            'sgd': {'momentum': 0.9}
        },
        'learning_rate': {
            'type': 'stepwise',
            'stepwise': {
                'boundaries': [10000, 20000],
                'values': [0.1, 0.01, 0.001]
            }
        },
        'warmup': {
            'type': 'linear',
            'linear': {'warmup_steps': 500, 'warmup_learning_rate': 0.01}
        }
    }
    expected_lr_step_values = [[0, 0.01], [250, 0.055], [500, 0.1], [5500, 0.1],
                               [10000, 0.1], [10001, 0.01], [20000, 0.01],
                               [20001, 0.001]]
    opt_config = optimization_config.OptimizationConfig(params)
    opt_factory = optimizer_factory.OptimizerFactory(opt_config)
    lr = opt_factory.build_learning_rate()

@@ -140,7 +130,9 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
    params = {
        'optimizer': {
            'type': 'sgd',
            'sgd': {'momentum': 0.9}
        },
        'learning_rate': {
            'type': 'exponential',

@@ -170,7 +162,9 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
    params = {
        'optimizer': {
            'type': 'sgd',
            'sgd': {'momentum': 0.9}
        },
        'learning_rate': {
            'type': 'polynomial',

@@ -194,7 +188,9 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
    params = {
        'optimizer': {
            'type': 'sgd',
            'sgd': {'momentum': 0.9}
        },
        'learning_rate': {
            'type': 'cosine',

@@ -204,11 +200,8 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
        }
    }
    expected_lr_step_values = [[0, 0.1], [250, 0.08535534], [500, 0.04999999],
                               [750, 0.01464466], [1000, 0]]
    opt_config = optimization_config.OptimizationConfig(params)
    opt_factory = optimizer_factory.OptimizerFactory(opt_config)
    lr = opt_factory.build_learning_rate()

@@ -220,7 +213,9 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
    params = {
        'optimizer': {
            'type': 'sgd',
            'sgd': {'momentum': 0.9}
        },
        'learning_rate': {
            'type': 'constant',

@@ -250,28 +245,28 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
    params = {
        'optimizer': {
            'type': 'sgd',
            'sgd': {'momentum': 0.9}
        },
        'learning_rate': {
            'type': 'stepwise',
            'stepwise': {
                'boundaries': [10000, 20000],
                'values': [0.1, 0.01, 0.001]
            }
        },
        'warmup': {
            'type': 'polynomial',
            'polynomial': {'warmup_steps': 500, 'power': 2.}
        }
    }
    expected_lr_step_values = [[0, 0.0], [250, 0.025], [500, 0.1], [5500, 0.1],
                               [10000, 0.1], [10001, 0.01], [20000, 0.01],
                               [20001, 0.001]]
    opt_config = optimization_config.OptimizationConfig(params)
    opt_factory = optimizer_factory.OptimizerFactory(opt_config)
    lr = opt_factory.build_learning_rate()

@@ -279,5 +274,6 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
    for step, value in expected_lr_step_values:
      self.assertAlmostEqual(lr(step).numpy(), value)


if __name__ == '__main__':
  tf.test.main()
official/modeling/performance.py  View file @ 999fae62
...
@@ -21,7 +21,7 @@ import tensorflow as tf
def configure_optimizer(optimizer,
                        use_float16=False,
                        use_graph_rewrite=False,
                        loss_scale='dynamic'):
  """Configures optimizer object with performance options."""
  if use_float16:
    # Wraps optimizer with a LossScaleOptimizer. This is done automatically
...
@@ -47,10 +47,9 @@ def set_mixed_precision_policy(dtype, loss_scale=None):
        'mixed_float16', loss_scale=loss_scale)
    tf.keras.mixed_precision.experimental.set_policy(policy)
  elif dtype == tf.bfloat16:
    policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')
    tf.keras.mixed_precision.experimental.set_policy(policy)
  elif dtype == tf.float32:
    tf.keras.mixed_precision.experimental.set_policy('float32')
  else:
    raise ValueError('Unexpected dtype: %s' % dtype)
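Together these helpers set a global Keras mixed-precision policy and wrap the optimizer for loss scaling. A minimal usage sketch, assuming the TF 2.3-era experimental API shown above (optimizer values are illustrative):

import tensorflow as tf
from official.modeling import performance

# float16 compute with dynamic loss scaling; configure_optimizer then wraps the
# optimizer in a LossScaleOptimizer so small float16 gradients do not underflow.
performance.set_mixed_precision_policy(tf.float16, loss_scale='dynamic')
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
optimizer = performance.configure_optimizer(
    optimizer, use_float16=True, loss_scale='dynamic')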
official/modeling/tf_utils.py  View file @ 999fae62
...
@@ -29,8 +29,7 @@ from official.modeling import activations
    None,
    "tf.keras.layers.Layer supports multiple positional args and kwargs as "
    "input tensors. pack/unpack inputs to override __call__ is no longer "
    "needed.")
def pack_inputs(inputs):
  """Pack a list of `inputs` tensors to a tuple.
...
@@ -55,8 +54,7 @@ def pack_inputs(inputs):
    None,
    "tf.keras.layers.Layer supports multiple positional args and kwargs as "
    "input tensors. pack/unpack inputs to override __call__ is no longer "
    "needed.")
def unpack_inputs(inputs):
  """unpack a tuple of `inputs` tensors to a tuple.
...
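pack_inputs and unpack_inputs exist only so older layers can funnel several (possibly None) inputs through a single positional argument, which is why both are now deprecated. A simplified illustration of the idea, not the library implementation (names and the placeholder are hypothetical):

import tensorflow as tf

_NONE_PLACEHOLDER = tf.constant(0)  # stand-in for None entries, illustrative only

def pack_inputs_sketch(inputs):
  # Replace None with a placeholder tensor so the whole structure is tensors.
  return tuple(_NONE_PLACEHOLDER if x is None else x for x in inputs)

def unpack_inputs_sketch(inputs):
  # Restore None where the placeholder was inserted.
  return tuple(None if x is _NONE_PLACEHOLDER else x for x in inputs)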
official/modeling/training/distributed_executor.py  View file @ 999fae62
...
@@ -133,15 +133,9 @@ class SummaryWriter(object):
class DistributedExecutor(object):
  """Interface to train and eval models with tf.distribute.Strategy."""

  def __init__(self,
               strategy,
               params,
               model_fn,
               loss_fn,
               is_multi_host=False):
    """Constructor.

    Args:
...
@@ -293,8 +287,7 @@ class DistributedExecutor(object):
      raise ValueError('steps should be an Tensor. Python object may cause '
                       'retracing.')

    per_replica_losses = strategy.run(
        replicated_step, args=(next(iterator),))
    for _ in tf.range(num_steps - 1):
      per_replica_losses = strategy.run(
          replicated_step, args=(next(iterator),))
...
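This hunk issues the first replicated step and then loops over the remaining steps with tf.range so the whole block stays inside one traced function; that is also why steps must already be a Tensor. A minimal sketch of the same pattern with a toy replicated_step (everything except the tf.distribute calls is a placeholder):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
  optimizer = tf.keras.optimizers.SGD(0.1)

def replicated_step(inputs):
  features, labels = inputs
  with tf.GradientTape() as tape:
    loss = tf.reduce_mean(tf.square(model(features) - labels))
  grads = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))
  return loss

@tf.function
def train_steps(iterator, num_steps):
  # num_steps is a Tensor; a Python int would retrace this function per value.
  per_replica_losses = strategy.run(replicated_step, args=(next(iterator),))
  for _ in tf.range(num_steps - 1):
    per_replica_losses = strategy.run(replicated_step, args=(next(iterator),))
  return strategy.reduce(
      tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)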
@@ -368,6 +361,7 @@ class DistributedExecutor(object):
        available checkpoints. If `False`, will do the evaluation once after the
        final step.
      save_config: bool. Whether to save params to model_dir.

    Returns:
      The training loss and eval metrics.
    """
...
@@ -477,16 +471,15 @@ class DistributedExecutor(object):
    # Step-0 operations
    if current_step == 0 and not latest_checkpoint_file:
      _save_checkpoint(checkpoint, model_dir,
                       checkpoint_name.format(step=current_step))
      if test_step:
        eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
        eval_metric_result = self._run_evaluation(test_step, current_step,
                                                  eval_metric, eval_iterator)
        logging.info('Step: %s evalation metric = %s.', current_step,
                     eval_metric_result)
        test_summary_writer(metrics=eval_metric_result,
                            step=optimizer.iterations)
        reset_states(eval_metric)

    logging.info('Training started')
...
@@ -519,8 +512,7 @@ class DistributedExecutor(object):
        else:
          train_metric_result.update({'learning_rate': optimizer.lr.numpy()})
        logging.info('Train Step: %d/%d / loss = %s / training metric = %s',
                     current_step, total_steps, train_loss,
                     train_metric_result)
        train_summary_writer(metrics=train_metric_result,
                             step=optimizer.iterations)
...
@@ -561,8 +553,7 @@ class DistributedExecutor(object):
      eval_metric_result = self._run_evaluation(test_step, current_step,
                                                eval_metric, eval_iterator)
      logging.info('Final evaluation metric = %s.', eval_metric_result)
      test_summary_writer(metrics=eval_metric_result,
                          step=optimizer.iterations)
    self.train_summary_writer.close()
    self.eval_summary_writer.close()
...
@@ -696,9 +687,8 @@ class DistributedExecutor(object):
    reader = tf.compat.v1.train.NewCheckpointReader(checkpoint_path)
    current_step = reader.get_tensor(
        'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE')
    logging.info('Checkpoint file %s found and restoring from '
                 'checkpoint', checkpoint_path)
    status = checkpoint.restore(checkpoint_path)
    status.expect_partial().assert_existing_objects_matched()
...
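The restore path above pulls the optimizer's iteration counter directly out of the checkpoint so training resumes at the right step. A minimal sketch of the same lookup with the public reader API (the checkpoint path is hypothetical; the variable key is the one used above):

import tensorflow as tf

ckpt_path = '/tmp/model_dir/ctl_step_1000.ckpt-1'  # hypothetical path
reader = tf.train.load_checkpoint(ckpt_path)
current_step = reader.get_tensor('optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE')
print('resuming from step', int(current_step))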
@@ -755,8 +745,8 @@ class ExecutorBuilder(object):
  """

  def __init__(self, strategy_type=None, strategy_config=None):
    _ = distribution_utils.configure_cluster(strategy_config.worker_hosts,
                                             strategy_config.task_index)
    """Constructor.

    Args:
...
official/nlp/albert/configs.py  View file @ 999fae62
...
@@ -26,10 +26,7 @@ from official.nlp.bert import configs
class AlbertConfig(configs.BertConfig):
  """Configuration for `ALBERT`."""

  def __init__(self,
               num_hidden_groups=1,
               inner_group_num=1,
               **kwargs):
    """Constructs AlbertConfig.

    Args:
...
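AlbertConfig only adds the two parameter-sharing knobs on top of configs.BertConfig; everything else is forwarded through **kwargs. A small usage sketch (the vocab size and other values are illustrative, not taken from this commit):

from official.nlp.albert import configs as albert_configs

# Illustrative values; BertConfig supplies the remaining defaults via **kwargs.
albert_config = albert_configs.AlbertConfig(
    vocab_size=30000,
    num_hidden_groups=1,   # hidden layers share parameters within each group
    inner_group_num=1)     # inner (sub-)groups per hidden group
print(albert_config.to_dict())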
official/nlp/albert/export_albert_tfhub.py  View file @ 999fae62
...
@@ -18,6 +18,7 @@ from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

# Import libraries
from absl import app
from absl import flags
import tensorflow as tf
...
official/nlp/albert/run_classifier.py  View file @ 999fae62
...
@@ -21,6 +21,7 @@ from __future__ import print_function
import json
import os

# Import libraries
from absl import app
from absl import flags
from absl import logging
...