Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
77026626
Commit
77026626
authored
Jul 11, 2022
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 460263057
parent
7d45e7b9
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
47 additions
and
14 deletions
+47
-14
official/core/savedmodel_checkpoint_manager.py
official/core/savedmodel_checkpoint_manager.py
+44
-14
official/core/savedmodel_checkpoint_manager_test.py
official/core/savedmodel_checkpoint_manager_test.py
+3
-0
No files found.
official/core/savedmodel_checkpoint_manager.py
View file @
77026626
...
...
@@ -15,15 +15,16 @@
"""Custom checkpoint manager that also exports saved models."""
import
os
import
re
from
typing
import
Callable
,
Mapping
,
Optional
from
absl
import
logging
import
tensorflow
as
tf
_SAVED_MODULES_PATH_SUFFIX
=
'saved_modules'
def
make_saved_modules_directory_name
(
checkpoint_name
:
str
)
->
str
:
return
f
'
{
checkpoint_name
}
_
saved_modules
'
return
f
'
{
checkpoint_name
}
_
{
_SAVED_MODULES_PATH_SUFFIX
}
'
class
SavedModelCheckpointManager
(
tf
.
train
.
CheckpointManager
):
...
...
@@ -50,6 +51,7 @@ class SavedModelCheckpointManager(tf.train.CheckpointManager):
checkpoint_interval
=
checkpoint_interval
,
init_fn
=
init_fn
)
self
.
_modules_to_export
=
modules_to_export
self
.
_savedmodels
=
self
.
_get_existing_savedmodels
()
def
save
(
self
,
checkpoint_number
=
None
,
...
...
@@ -73,21 +75,49 @@ class SavedModelCheckpointManager(tf.train.CheckpointManager):
obj
=
model
,
export_dir
=
os
.
path
.
join
(
saved_modules_directory
,
model_name
))
# `checkpoint_path` ends in `-[\d]+`. We want to glob for all existing
# checkpoints, and we use the .index file for that.
checkpoint_glob
=
re
.
sub
(
r
'\d+$'
,
'*.index'
,
checkpoint_path
)
existing_checkpoint_files
=
tf
.
io
.
gfile
.
glob
(
checkpoint_glob
)
saved_modules_directories_to_keep
=
[
make_saved_modules_directory_name
(
os
.
path
.
splitext
(
ckpt_index
)[
0
])
for
ckpt_index
in
existing_checkpoint_files
make_saved_modules_directory_name
(
ckpt
)
for
ckpt
in
self
.
checkpoints
]
saved_modules_glob
=
re
.
sub
(
r
'\d+_saved_modules$'
,
'*_saved_modules'
,
saved_modules_directory
)
existing_saved_modules_dirs
=
self
.
_get_existing_savedmodels
()
self
.
_savedmodels
=
[]
# Keep savedmodels in the same order as checkpoints (from oldest to newest).
for
saved_modules_dir_to_keep
in
saved_modules_directories_to_keep
:
if
saved_modules_dir_to_keep
in
existing_saved_modules_dirs
:
self
.
_savedmodels
.
append
(
saved_modules_dir_to_keep
)
for
existing_saved_modules_dir
in
tf
.
io
.
gfile
.
glob
(
saved_modules_glob
):
if
(
existing_saved_modules_dir
not
in
saved_modules_directories_to_keep
and
tf
.
io
.
gfile
.
isdir
(
existing_saved_modules_dir
)):
for
existing_saved_modules_dir
in
existing_saved_modules_dirs
:
if
existing_saved_modules_dir
not
in
self
.
_savedmodels
:
tf
.
io
.
gfile
.
rmtree
(
existing_saved_modules_dir
)
return
checkpoint_path
def
_get_existing_savedmodels
(
self
):
"""Gets a list of all existing SavedModel paths in `directory`.
Returns:
A list of all existing SavedModel paths.
"""
saved_modules_glob
=
make_saved_modules_directory_name
(
self
.
_checkpoint_prefix
+
'-*'
)
return
tf
.
io
.
gfile
.
glob
(
saved_modules_glob
)
@
property
def
latest_savedmodel
(
self
):
"""The path of the most recent SavedModel in `directory`.
Returns:
The latest SavedModel path. If there are no SavedModels, returns `None`.
"""
if
self
.
_savedmodels
:
return
self
.
_savedmodels
[
-
1
]
return
None
@
property
def
savedmodels
(
self
):
"""A list of managed SavedModels.
Returns:
A list of SavedModel paths, sorted from oldest to newest.
"""
return
self
.
_savedmodels
official/core/savedmodel_checkpoint_manager_test.py
View file @
77026626
...
...
@@ -51,6 +51,9 @@ class CheckpointManagerTest(tf.test.TestCase):
first_path
=
manager
.
save
()
second_path
=
manager
.
save
()
savedmodel
=
savedmodel_checkpoint_manager
.
make_saved_modules_directory_name
(
manager
.
latest_checkpoint
)
self
.
assertEqual
(
savedmodel
,
manager
.
latest_savedmodel
)
self
.
assertTrue
(
_models_exist
(
second_path
,
models
.
keys
()))
self
.
assertFalse
(
_models_exist
(
first_path
,
models
.
keys
()))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment