Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
f69ef1cd
Commit
f69ef1cd
authored
Jul 23, 2021
by
Frederick Liu
Committed by
A. Unique TensorFlower
Jul 23, 2021
Browse files
Internal change
PiperOrigin-RevId: 386531154
parent
7c2ff1af
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
6 deletions
+9
-6
official/core/train_utils.py
official/core/train_utils.py
+9
-6
No files found.
official/core/train_utils.py
View file @
f69ef1cd
...
...
@@ -142,14 +142,19 @@ class BestCheckpointExporter:
return
self
.
_checkpoint_manager
def
maybe_export_checkpoint
(
self
,
checkpoint
,
eval_logs
,
global_step
):
def
maybe_export_checkpoint
(
self
,
checkpoint
,
eval_logs
,
global_step
,
write_logs
=
True
)
->
bool
:
"""Compare eval_logs with past eval_logs and export checkpoint if better."""
logging
.
info
(
'[BestCheckpointExporter] received eval_logs: %s, at step: %d'
,
eval_logs
,
global_step
)
if
self
.
_best_ckpt_logs
is
None
or
self
.
_new_metric_is_better
(
self
.
_best_ckpt_logs
,
eval_logs
):
self
.
_best_ckpt_logs
=
eval_logs
self
.
_export_best_eval_metric
(
checkpoint
,
self
.
_best_ckpt_logs
,
global_step
)
if
write_logs
:
self
.
export_best_eval_metric
(
self
.
_best_ckpt_logs
,
global_step
)
self
.
_get_checkpoint_manager
(
checkpoint
).
save
()
return
True
return
False
def
_maybe_load_best_eval_metric
(
self
):
if
not
tf
.
io
.
gfile
.
exists
(
self
.
best_ckpt_logs_path
):
...
...
@@ -180,7 +185,7 @@ class BestCheckpointExporter:
return
True
return
False
def
_
export_best_eval_metric
(
self
,
checkpoint
,
eval_logs
,
global_step
):
def
export_best_eval_metric
(
self
,
eval_logs
,
global_step
):
"""Export evaluation results of the best checkpoint into a json file."""
eval_logs_ext
=
copy
.
copy
(
eval_logs
)
eval_logs_ext
[
'best_ckpt_global_step'
]
=
global_step
...
...
@@ -190,8 +195,6 @@ class BestCheckpointExporter:
with
tf
.
io
.
gfile
.
GFile
(
self
.
best_ckpt_logs_path
,
'w'
)
as
writer
:
writer
.
write
(
json
.
dumps
(
eval_logs_ext
,
indent
=
4
)
+
'
\n
'
)
self
.
_get_checkpoint_manager
(
checkpoint
).
save
()
@
property
def
best_ckpt_logs
(
self
):
return
self
.
_best_ckpt_logs
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment