Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
c508968c
Commit
c508968c
authored
Feb 11, 2021
by
Chen Chen
Committed by
A. Unique TensorFlower
Feb 11, 2021
Browse files
Internal change
PiperOrigin-RevId: 357078424
parent
b9b0be18
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
5 deletions
+17
-5
official/core/train_utils.py
official/core/train_utils.py
+17
-5
No files found.
official/core/train_utils.py
View file @
c508968c
...
...
@@ -56,6 +56,20 @@ class BestCheckpointExporter:
'higher, lower. Got: {}'
.
format
(
self
.
_metric_comp
))
tf
.
io
.
gfile
.
makedirs
(
os
.
path
.
dirname
(
self
.
best_ckpt_logs_path
))
self
.
_best_ckpt_logs
=
self
.
_maybe_load_best_eval_metric
()
self
.
_checkpoint_manager
=
None
def
_get_checkpoint_manager
(
self
,
checkpoint
):
"""Gets an existing checkpoint manager or creates a new one."""
if
self
.
_checkpoint_manager
is
None
or
(
self
.
_checkpoint_manager
.
checkpoint
!=
checkpoint
):
logging
.
info
(
'Creates a new checkpoint manager.'
)
self
.
_checkpoint_manager
=
tf
.
train
.
CheckpointManager
(
checkpoint
,
directory
=
self
.
_export_dir
,
max_to_keep
=
1
,
checkpoint_name
=
'best_ckpt'
)
return
self
.
_checkpoint_manager
def
maybe_export_checkpoint
(
self
,
checkpoint
,
eval_logs
,
global_step
):
logging
.
info
(
'[BestCheckpointExporter] received eval_logs: %s, at step: %d'
,
...
...
@@ -105,10 +119,7 @@ class BestCheckpointExporter:
with
tf
.
io
.
gfile
.
GFile
(
self
.
best_ckpt_logs_path
,
'w'
)
as
writer
:
writer
.
write
(
json
.
dumps
(
eval_logs_ext
,
indent
=
4
)
+
'
\n
'
)
# Saving the best checkpoint might be interrupted if the job got killed.
for
file_to_remove
in
tf
.
io
.
gfile
.
glob
(
self
.
best_ckpt_path
+
'*'
):
tf
.
io
.
gfile
.
remove
(
file_to_remove
)
checkpoint
.
write
(
self
.
best_ckpt_path
)
self
.
_get_checkpoint_manager
(
checkpoint
).
save
()
@
property
def
best_ckpt_logs
(
self
):
...
...
@@ -120,7 +131,8 @@ class BestCheckpointExporter:
@
property
def
best_ckpt_path
(
self
):
return
os
.
path
.
join
(
self
.
_export_dir
,
'best_ckpt'
)
"""Returns the best ckpt path or None if there is no ckpt yet."""
return
tf
.
train
.
latest_checkpoint
(
self
.
_export_dir
)
@
gin
.
configurable
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment