ModelZoo / ResNet50_tensorflow

Commit 40e12432, authored Aug 19, 2020 by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 327459481
Parent: 30821184
Showing 5 changed files with 189 additions and 8 deletions.
official/core/base_trainer.py (+14, -1)
official/core/base_trainer_test.py (+26, -0)
official/core/train_lib.py (+119, -1)
official/core/train_utils.py (+15, -6)
official/modeling/hyperparams/config_definitions.py (+15, -0)
official/core/base_trainer.py

@@ -42,7 +42,8 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
                train: bool = True,
                evaluate: bool = True,
                model=None,
-               optimizer=None):
+               optimizer=None,
+               checkpoint_exporter=None):
     """Initialize common trainer for TensorFlow models.

     Args:
@@ -56,6 +57,8 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
         building model using task.build_model(). Defaults to None.
       optimizer: tf.keras.optimizers.Optimizer instance. If provided, it will
         be used instead of the optimizer from config. Defaults to None.
+      checkpoint_exporter: an object that has the `maybe_export_checkpoint`
+        interface.
     """
     # Gets the current distribution strategy. If not inside any strategy scope,
     # it gets a single-replica no-op strategy.
@@ -73,6 +76,8 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
     else:
       self._optimizer = optimizer
+    self._checkpoint_exporter = checkpoint_exporter
+
     # Configuring optimizer when loss_scale is set in runtime config. This helps
     # avoiding overflow/underflow for float16 computations.
     if config.runtime.loss_scale:
@@ -235,6 +240,14 @@ class Trainer(orbit.StandardTrainer, orbit.StandardEvaluator):
     if aggregated_logs:
       metrics = self.task.reduce_aggregated_logs(aggregated_logs)
       logs.update(metrics)
+
+    if self._checkpoint_exporter:
+      self._checkpoint_exporter.maybe_export_checkpoint(
+          self.checkpoint, logs, self.global_step.numpy())
+      metric_name = self.config.trainer.best_checkpoint_eval_metric
+      logs['best_' + metric_name] = (
+          self._checkpoint_exporter.best_ckpt_logs[metric_name])
+
     return logs

   def eval_reduce(self, state=None, step_outputs=None):
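The trainer depends only on duck typing here: `Trainer.evaluate()` calls `maybe_export_checkpoint(checkpoint, logs, global_step)` and then reads `best_ckpt_logs[metric_name]`. A minimal sketch of that contract follows; the class name and its keep-first policy are illustrative assumptions, not part of the commit.

# Minimal object satisfying the `checkpoint_exporter` contract used by
# base_trainer.Trainer. Everything except the method and property names
# is a hypothetical example.
class FirstResultExporter:

  def __init__(self):
    self._best_ckpt_logs = None

  def maybe_export_checkpoint(self, checkpoint, eval_logs, global_step):
    # Receives the trainer's tf.train.Checkpoint, the eval logs dict, and
    # the integer global step; decides whether to persist anything.
    if self._best_ckpt_logs is None:
      self._best_ckpt_logs = eval_logs  # Illustrative: keep the first result.

  @property
  def best_ckpt_logs(self):
    # Trainer.evaluate() reads logs['best_' + metric_name] from this mapping.
    return self._best_ckpt_logs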
official/core/base_trainer_test.py

@@ -16,12 +16,14 @@
 """Tests for tensorflow_models.core.trainers.trainer."""
 # pylint: disable=g-direct-tensorflow-import
+import os
 from absl.testing import parameterized
 import tensorflow as tf

 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import strategy_combinations
 from official.core import base_trainer as trainer_lib
+from official.core import train_lib
 from official.modeling.hyperparams import config_definitions as cfg
 from official.utils.testing import mock_task
@@ -105,6 +107,30 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
     metrics = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
     self.assertIn('training_loss', metrics)

+  @combinations.generate(all_strategy_combinations())
+  def test_export_best_ckpt(self, distribution):
+    config = cfg.ExperimentConfig(
+        trainer=cfg.TrainerConfig(
+            best_checkpoint_export_subdir='best_ckpt',
+            best_checkpoint_eval_metric='acc',
+            optimizer_config=cfg.OptimizationConfig({
+                'optimizer': {'type': 'sgd'},
+                'learning_rate': {'type': 'constant'}
+            })))
+    model_dir = self.get_temp_dir()
+    task = mock_task.MockTask(config.task, logging_dir=model_dir)
+    ckpt_exporter = train_lib.maybe_create_best_ckpt_exporter(
+        config, model_dir)
+    trainer = trainer_lib.Trainer(
+        config, task, checkpoint_exporter=ckpt_exporter)
+    trainer.train(tf.convert_to_tensor(1, dtype=tf.int32))
+    trainer.evaluate(tf.convert_to_tensor(1, dtype=tf.int32))
+    self.assertTrue(
+        tf.io.gfile.exists(os.path.join(model_dir, 'best_ckpt', 'info.json')))
+

 if __name__ == '__main__':
   tf.test.main()
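The test drives the exporter through the full trainer loop; it can also be exercised on its own. A minimal standalone sketch, assuming a scratch directory and fabricated 'acc' values (`BestCheckpointExporter` is defined in the train_lib.py change below):

import tempfile

import tensorflow as tf

from official.core import train_lib

export_dir = tempfile.mkdtemp()  # Scratch directory for this sketch.
exporter = train_lib.BestCheckpointExporter(export_dir, 'acc', 'higher')

ckpt = tf.train.Checkpoint(step=tf.Variable(0, dtype=tf.int64))
exporter.maybe_export_checkpoint(ckpt, {'acc': 0.7}, 100)  # First result wins.
exporter.maybe_export_checkpoint(ckpt, {'acc': 0.9}, 200)  # Higher, exported.
exporter.maybe_export_checkpoint(ckpt, {'acc': 0.8}, 300)  # Worse, ignored.

# The winning logs are persisted to <export_dir>/info.json and the
# checkpoint files under <export_dir>/best_ckpt*.
print(exporter.best_ckpt_logs)  # {'acc': 0.9}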
official/core/train_lib.py

@@ -15,6 +15,8 @@
 # ==============================================================================
 """TFM common training driver library."""
+import copy
+import json
 import os
 from typing import Any, Mapping, Tuple
@@ -28,6 +30,121 @@ from official.core import base_task
 from official.modeling.hyperparams import config_definitions


+class BestCheckpointExporter:
+  """Keeps track of the best result, and saves its checkpoint.
+
+  Orbit will support an API for checkpoint exporter. This class will be used
+  together with orbit once this functionality is ready.
+  """
+
+  def __init__(self, export_dir: str, metric_name: str, metric_comp: str):
+    """Initialization.
+
+    Arguments:
+      export_dir: The directory that will contain exported checkpoints.
+      metric_name: Indicates which metric to look at, when determining which
+        result is better.
+      metric_comp: Indicates how to compare results. Either `lower` or
+        `higher`.
+    """
+    self._export_dir = export_dir
+    self._metric_name = metric_name
+    self._metric_comp = metric_comp
+    if self._metric_comp not in ('lower', 'higher'):
+      raise ValueError('best checkpoint metric comp must be one of '
+                       'higher, lower. Got: {}'.format(self._metric_comp))
+    tf.io.gfile.makedirs(os.path.dirname(self.best_ckpt_logs_path))
+    self._best_ckpt_logs = self._maybe_load_best_eval_metric()
+
+  def maybe_export_checkpoint(self, checkpoint, eval_logs, global_step):
+    logging.info('[BestCheckpointExporter] received eval_logs: %s, at step: %d',
+                 eval_logs, global_step)
+    if self._best_ckpt_logs is None or self._new_metric_is_better(
+        self._best_ckpt_logs, eval_logs):
+      self._best_ckpt_logs = eval_logs
+      self._export_best_eval_metric(checkpoint, self._best_ckpt_logs,
+                                    global_step)
+
+  def _maybe_load_best_eval_metric(self):
+    if not tf.io.gfile.exists(self.best_ckpt_logs_path):
+      return None
+    with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'r') as reader:
+      return json.loads(reader.read())
+
+  def _new_metric_is_better(self, old_logs, new_logs):
+    """Check if the metric in new_logs is better than the metric in old_logs."""
+    if (self._metric_name not in old_logs or
+        self._metric_name not in new_logs):
+      raise KeyError('best checkpoint eval metric name {} is not valid. '
+                     'old_logs: {}, new_logs: {}'.format(
+                         self._metric_name, old_logs, new_logs))
+    old_value = float(orbit.utils.get_value(old_logs[self._metric_name]))
+    new_value = float(orbit.utils.get_value(new_logs[self._metric_name]))
+    logging.info('[BestCheckpointExporter] comparing results. old: %f, new: %f',
+                 old_value, new_value)
+    if self._metric_comp == 'higher':
+      if new_value > old_value:
+        logging.info('[BestCheckpointExporter] '
+                     'the new number is better since it is higher.')
+        return True
+    else:  # self._metric_comp == 'lower':
+      if new_value < old_value:
+        logging.info('[BestCheckpointExporter] '
+                     'the new number is better since it is lower.')
+        return True
+    return False
+
+  def _export_best_eval_metric(self, checkpoint, eval_logs, global_step):
+    """Export evaluation results of the best checkpoint into a json file."""
+    eval_logs_ext = copy.copy(eval_logs)
+    eval_logs_ext['best_ckpt_global_step'] = global_step
+    for name, value in eval_logs_ext.items():
+      eval_logs_ext[name] = str(orbit.utils.get_value(value))
+    # Saving json file is very fast.
+    with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'w') as writer:
+      writer.write(json.dumps(eval_logs_ext, indent=4) + '\n')
+
+    # Saving the best checkpoint might be interrupted if the job got killed.
+    for file_to_remove in tf.io.gfile.glob(self.best_ckpt_path + '*'):
+      tf.io.gfile.rmtree(file_to_remove)
+    checkpoint.save(self.best_ckpt_path)
+
+  @property
+  def best_ckpt_logs(self):
+    return self._best_ckpt_logs
+
+  @property
+  def best_ckpt_logs_path(self):
+    return os.path.join(self._export_dir, 'info.json')
+
+  @property
+  def best_ckpt_path(self):
+    return os.path.join(self._export_dir, 'best_ckpt')
+
+
+def maybe_create_best_ckpt_exporter(
+    params: config_definitions.ExperimentConfig, data_dir: str) -> Any:
+  """Maybe create a BestCheckpointExporter object, according to the config."""
+  export_subdir = params.trainer.best_checkpoint_export_subdir
+  metric_name = params.trainer.best_checkpoint_eval_metric
+  metric_comp = params.trainer.best_checkpoint_metric_comp
+  if data_dir and export_subdir and metric_name:
+    best_ckpt_dir = os.path.join(data_dir, export_subdir)
+    best_ckpt_exporter = BestCheckpointExporter(best_ckpt_dir, metric_name,
+                                                metric_comp)
+  else:
+    best_ckpt_exporter = None
+    logging.info('Not exporting the best checkpoint. '
+                 'data_dir: %s, export_subdir: %s, metric_name: %s',
+                 data_dir, export_subdir, metric_name)
+  return best_ckpt_exporter
+
+
 def run_experiment(distribution_strategy: tf.distribute.Strategy,
                    task: base_task.Task,
                    mode: str,
@@ -62,7 +179,8 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy,
       task,
       model_dir,
       train='train' in mode,
-      evaluate=('eval' in mode) or run_post_eval)
+      evaluate=('eval' in mode) or run_post_eval,
+      checkpoint_exporter=maybe_create_best_ckpt_exporter(params, model_dir))

   if trainer.checkpoint:
     checkpoint_manager = tf.train.CheckpointManager(
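Note the factory's gating: the exporter is created only when the model directory, `best_checkpoint_export_subdir`, and `best_checkpoint_eval_metric` are all non-empty, so existing configs run unchanged. A small sketch of that behavior, with hypothetical paths:

from official.core import train_lib
from official.modeling.hyperparams import config_definitions as cfg

params = cfg.ExperimentConfig(
    trainer=cfg.TrainerConfig(
        best_checkpoint_export_subdir='best_ckpt',  # Hypothetical subdir.
        best_checkpoint_eval_metric='acc'))         # Hypothetical metric.

# All three conditions met: an exporter is returned.
assert train_lib.maybe_create_best_ckpt_exporter(
    params, '/tmp/model_dir') is not None

# The default config leaves the subdir and metric empty, so the factory
# logs a notice and returns None; run_experiment then trains without export.
assert train_lib.maybe_create_best_ckpt_exporter(
    cfg.ExperimentConfig(), '/tmp/model_dir') is None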
official/core/train_utils.py

@@ -18,20 +18,32 @@
 import json
 import os
 import pprint
+from typing import Any

 from absl import logging
+import orbit
 import tensorflow as tf

+from official.core import base_task
 from official.core import base_trainer
 from official.core import exp_factory
 from official.modeling import hyperparams
 from official.modeling.hyperparams import config_definitions


-def create_trainer(params, task, model_dir, train, evaluate):
+def create_trainer(params: config_definitions.ExperimentConfig,
+                   task: base_task.Task,
+                   model_dir: str,
+                   train: bool,
+                   evaluate: bool,
+                   checkpoint_exporter: Any = None):
   """Create trainer."""
   del model_dir
   logging.info('Running default trainer.')
-  trainer = base_trainer.Trainer(params, task, train=train, evaluate=evaluate)
+  trainer = base_trainer.Trainer(params, task, train=train, evaluate=evaluate,
+                                 checkpoint_exporter=checkpoint_exporter)
   return trainer
@@ -122,10 +134,7 @@ def write_summary(summary_writer, global_step, eval_metrics):
   """Write evaluation metrics to TF summary."""
   numeric_dict = {}
   for name, value in eval_metrics.items():
-    if hasattr(value, 'numpy'):
-      numeric_dict[name] = value.numpy().astype(float)
-    else:
-      numeric_dict[name] = value
+    numeric_dict[name] = float(orbit.utils.get_value(value))
   with summary_writer.as_default():
     for name, value in numeric_dict.items():
       tf.summary.scalar(name, value, step=global_step)
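The write_summary change funnels every metric through orbit.utils.get_value, which this commit already relies on elsewhere, instead of special-casing objects with a `numpy` attribute. A small sketch of the normalization, with fabricated metric values:

import orbit
import tensorflow as tf

# Hypothetical eval metrics mixing a tensor and a plain float.
eval_metrics = {'accuracy': tf.constant(0.91), 'loss': 0.35}

# Same normalization as the new write_summary loop: tensors are unwrapped,
# plain numbers pass through, and everything ends up a Python float.
numeric_dict = {}
for name, value in eval_metrics.items():
  numeric_dict[name] = float(orbit.utils.get_value(value))

print(numeric_dict)  # {'accuracy': 0.9..., 'loss': 0.35}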
official/modeling/hyperparams/config_definitions.py

@@ -183,6 +183,17 @@ class TrainerConfig(base_config.Config):
     validation_steps: number of eval steps. If `None`, the entire eval dataset
       is used.
     validation_interval: number of training steps to run between evaluations.
+    best_checkpoint_export_subdir: if set, the trainer will keep track of the
+      best evaluation metric, and export the corresponding best checkpoint
+      under `model_dir/best_checkpoint_export_subdir`. Note that this only
+      works if mode contains eval (such as `train_and_eval`,
+      `continuous_eval`, and `continuous_train_and_eval`).
+    best_checkpoint_eval_metric: for exporting the best checkpoint, which
+      evaluation metric the trainer should monitor. This can be any evaluation
+      metric that appears on TensorBoard.
+    best_checkpoint_metric_comp: for exporting the best checkpoint, how the
+      trainer should compare the evaluation metrics. This can be either
+      `higher` (higher the better) or `lower` (lower the better).
   """
   optimizer_config: OptimizationConfig = OptimizationConfig()
   # Orbit settings.
@@ -201,6 +212,10 @@ class TrainerConfig(base_config.Config):
   train_steps: int = 0
   validation_steps: Optional[int] = None
   validation_interval: int = 1000
+  # Best checkpoint export.
+  best_checkpoint_export_subdir: str = ""
+  best_checkpoint_eval_metric: str = ""
+  best_checkpoint_metric_comp: str = "higher"


 @dataclasses.dataclass
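Putting the new fields together: enabling export from user code requires only the subdirectory and metric name, since `best_checkpoint_metric_comp` already defaults to `higher`. A minimal configuration sketch, with an assumed `validation_loss` metric monitored as lower-is-better:

from official.modeling.hyperparams import config_definitions as cfg

# Hypothetical experiment config enabling best-checkpoint export.
config = cfg.ExperimentConfig(
    trainer=cfg.TrainerConfig(
        best_checkpoint_export_subdir='best_ckpt',
        best_checkpoint_eval_metric='validation_loss',
        best_checkpoint_metric_comp='lower'))

# As the docstring notes, the run mode must include eval (e.g.
# 'train_and_eval') for the exporter to ever fire, because export happens
# at the end of Trainer.evaluate().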