ModelZoo / ResNet50_tensorflow

Commit 44cfd95e, authored Nov 17, 2020 by Hongkun Yu; committed by A. Unique TensorFlower, Nov 17, 2020

Internal change

PiperOrigin-RevId: 342910130
Parent: 49f081ee
Changes: 2 changed files, with 101 additions and 99 deletions.

official/core/train_lib.py    +3  -94
official/core/train_utils.py  +98 -5
official/core/train_lib.py
@@ -15,8 +15,6 @@
 # ==============================================================================
 """TFM common training driver library."""
 # pytype: disable=attribute-error
-import copy
-import json
 import os
 from typing import Any, Mapping, Tuple
@@ -29,96 +27,7 @@ from official.core import base_task
 from official.core import config_definitions
 from official.core import train_utils
 
-class BestCheckpointExporter:
-  """Keeps track of the best result, and saves its checkpoint.
-
-  Orbit will support an API for checkpoint exporter. This class will be used
-  together with orbit once this functionality is ready.
-  """
-  [... the remaining deleted lines are the class body, identical to the
-  BestCheckpointExporter class added to official/core/train_utils.py below ...]
+BestCheckpointExporter = train_utils.BestCheckpointExporter
 
 
 def maybe_create_best_ckpt_exporter(params: config_definitions.ExperimentConfig,
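A note on the alias added above: re-binding the name in train_lib keeps the old import path working while the implementation now lives in train_utils. A minimal sketch of what that equivalence means for callers (the assert is illustrative, not part of the commit):

    from official.core import train_lib
    from official.core import train_utils

    # After this change both names are bound to the same class object, so
    # existing call sites of train_lib.BestCheckpointExporter keep working.
    assert train_lib.BestCheckpointExporter is train_utils.BestCheckpointExporter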
@@ -129,8 +38,8 @@ def maybe_create_best_ckpt_exporter(params: config_definitions.ExperimentConfig,
   metric_comp = params.trainer.best_checkpoint_metric_comp
   if data_dir and export_subdir and metric_name:
     best_ckpt_dir = os.path.join(data_dir, export_subdir)
-    best_ckpt_exporter = BestCheckpointExporter(
-        best_ckpt_dir, metric_name, metric_comp)
+    best_ckpt_exporter = BestCheckpointExporter(best_ckpt_dir, metric_name,
+                                                metric_comp)
   else:
     best_ckpt_exporter = None
   logging.info(
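For context, maybe_create_best_ckpt_exporter builds an exporter only when the relevant trainer config fields are set, and otherwise returns None (the else branch above). A hedged sketch of the configuration it reads; only best_checkpoint_metric_comp is visible in this hunk, the other two field names and the data_dir parameter are assumptions based on the local variable names:

    # Assumed ExperimentConfig fields feeding metric_name, export_subdir,
    # and metric_comp in the function body above.
    params.trainer.best_checkpoint_eval_metric = 'accuracy'
    params.trainer.best_checkpoint_export_subdir = 'best_ckpt'
    params.trainer.best_checkpoint_metric_comp = 'higher'

    exporter = maybe_create_best_ckpt_exporter(params, data_dir=model_dir)
    # exporter is None if any of the three values above is empty.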
official/core/train_utils.py
@@ -14,14 +14,15 @@
 # limitations under the License.
 # ==============================================================================
 """Training utils."""
+import copy
 import json
 import os
 import pprint
-from typing import Any, List
+from typing import List, Optional
 from absl import logging
 import dataclasses
 import gin
 import orbit
 import tensorflow as tf
@@ -32,16 +33,109 @@ from official.core import exp_factory
 from official.modeling import hyperparams
 
 
+class BestCheckpointExporter:
+  """Keeps track of the best result, and saves its checkpoint.
+
+  Orbit will support an API for checkpoint exporter. This class will be used
+  together with orbit once this functionality is ready.
+  """
+
+  def __init__(self, export_dir: str, metric_name: str, metric_comp: str):
+    """Initialization.
+
+    Arguments:
+      export_dir: The directory that will contain exported checkpoints.
+      metric_name: Indicates which metric to look at, when determining which
+        result is better.
+      metric_comp: Indicates how to compare results. Either `lower` or `higher`.
+    """
+    self._export_dir = export_dir
+    self._metric_name = metric_name
+    self._metric_comp = metric_comp
+    if self._metric_comp not in ('lower', 'higher'):
+      raise ValueError('best checkpoint metric comp must be one of '
+                       'higher, lower. Got: {}'.format(self._metric_comp))
+    tf.io.gfile.makedirs(os.path.dirname(self.best_ckpt_logs_path))
+    self._best_ckpt_logs = self._maybe_load_best_eval_metric()
+
+  def maybe_export_checkpoint(self, checkpoint, eval_logs, global_step):
+    logging.info('[BestCheckpointExporter] received eval_logs: %s, at step: %d',
+                 eval_logs, global_step)
+    if self._best_ckpt_logs is None or self._new_metric_is_better(
+        self._best_ckpt_logs, eval_logs):
+      self._best_ckpt_logs = eval_logs
+      self._export_best_eval_metric(checkpoint, self._best_ckpt_logs,
+                                    global_step)
+
+  def _maybe_load_best_eval_metric(self):
+    if not tf.io.gfile.exists(self.best_ckpt_logs_path):
+      return None
+    with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'r') as reader:
+      return json.loads(reader.read())
+
+  def _new_metric_is_better(self, old_logs, new_logs):
+    """Check if the metric in new_logs is better than the metric in old_logs."""
+    if (self._metric_name not in old_logs or
+        self._metric_name not in new_logs):
+      raise KeyError('best checkpoint eval metric name {} is not valid. '
+                     'old_logs: {}, new_logs: {}'.format(
+                         self._metric_name, old_logs, new_logs))
+    old_value = float(orbit.utils.get_value(old_logs[self._metric_name]))
+    new_value = float(orbit.utils.get_value(new_logs[self._metric_name]))
+
+    logging.info('[BestCheckpointExporter] comparing results. old: %f, new: %f',
+                 old_value, new_value)
+    if self._metric_comp == 'higher':
+      if new_value > old_value:
+        logging.info('[BestCheckpointExporter] '
+                     'the new number is better since it is higher.')
+        return True
+    else:  # self._metric_comp == 'lower':
+      if new_value < old_value:
+        logging.info('[BestCheckpointExporter] '
+                     'the new number is better since it is lower.')
+        return True
+    return False
+
+  def _export_best_eval_metric(self, checkpoint, eval_logs, global_step):
+    """Export evaluation results of the best checkpoint into a json file."""
+    eval_logs_ext = copy.copy(eval_logs)
+    eval_logs_ext['best_ckpt_global_step'] = global_step
+    for name, value in eval_logs_ext.items():
+      eval_logs_ext[name] = str(orbit.utils.get_value(value))
+    # Saving json file is very fast.
+    with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'w') as writer:
+      writer.write(json.dumps(eval_logs_ext, indent=4) + '\n')
+
+    # Saving the best checkpoint might be interrupted if the job got killed.
+    for file_to_remove in tf.io.gfile.glob(self.best_ckpt_path + '*'):
+      tf.io.gfile.rmtree(file_to_remove)
+    checkpoint.save(self.best_ckpt_path)
+
+  @property
+  def best_ckpt_logs(self):
+    return self._best_ckpt_logs
+
+  @property
+  def best_ckpt_logs_path(self):
+    return os.path.join(self._export_dir, 'info.json')
+
+  @property
+  def best_ckpt_path(self):
+    return os.path.join(self._export_dir, 'best_ckpt')
+
+
 @gin.configurable
 def create_trainer(params: config_definitions.ExperimentConfig,
                    task: base_task.Task,
                    train: bool,
                    evaluate: bool,
-                   checkpoint_exporter: Any = None) -> base_trainer.Trainer:
+                   checkpoint_exporter: Optional[BestCheckpointExporter] = None,
+                   trainer_cls=base_trainer.Trainer) -> base_trainer.Trainer:
   """Create trainer."""
   logging.info('Running default trainer.')
   model = task.build_model()
   optimizer = base_trainer.create_optimizer(params.trainer, params.runtime)
-  trainer = base_trainer.Trainer(
+  return trainer_cls(
       params,
       task,
       model=model,
@@ -49,7 +143,6 @@ def create_trainer(params: config_definitions.ExperimentConfig,
       train=train,
       evaluate=evaluate,
       checkpoint_exporter=checkpoint_exporter)
-  return trainer
 
 
 @dataclasses.dataclass
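Taken together, the train_utils changes make the exporter and the trainer factory composable: create_trainer now accepts a typed checkpoint_exporter and a trainer_cls hook, and returns the constructed trainer directly. A minimal usage sketch under the new API; my_model, MyTrainer, params, and task are hypothetical stand-ins, the eval_logs values are illustrative, and the base_trainer import path is an assumption, not shown in this diff:

    import tensorflow as tf
    from official.core import base_trainer
    from official.core import train_utils

    # 1) Standalone best-checkpoint tracking: the checkpoint is saved only
    #    when 'accuracy' improves on the value recorded in info.json under
    #    the export directory.
    exporter = train_utils.BestCheckpointExporter(
        export_dir='/tmp/exp/best_ckpt', metric_name='accuracy',
        metric_comp='higher')
    ckpt = tf.train.Checkpoint(model=my_model)
    exporter.maybe_export_checkpoint(ckpt, eval_logs={'accuracy': 0.91},
                                     global_step=1000)

    # 2) The new trainer_cls hook: substitute a Trainer subclass without
    #    forking create_trainer.
    class MyTrainer(base_trainer.Trainer):
      pass

    trainer = train_utils.create_trainer(
        params, task, train=True, evaluate=True,
        checkpoint_exporter=exporter, trainer_cls=MyTrainer)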