Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
bf5a3910
Commit
bf5a3910
authored
Aug 22, 2022
by
Chaochao Yan
Committed by
A. Unique TensorFlower
Aug 22, 2022
Browse files
Internal change
PiperOrigin-RevId: 469264237
parent
eac1af65
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
185 additions
and
14 deletions
+185
-14
official/core/savedmodel_checkpoint_manager.py
official/core/savedmodel_checkpoint_manager.py
+128
-9
official/core/savedmodel_checkpoint_manager_test.py
official/core/savedmodel_checkpoint_manager_test.py
+57
-5
No files found.
official/core/savedmodel_checkpoint_manager.py
View file @
bf5a3910
...
@@ -15,16 +15,18 @@
...
@@ -15,16 +15,18 @@
"""Custom checkpoint manager that also exports saved models."""
"""Custom checkpoint manager that also exports saved models."""
import
os
import
os
from
typing
import
Callable
,
Mapping
,
Optional
import
re
import
time
from
typing
import
Callable
,
List
,
Mapping
,
Optional
,
Union
from
absl
import
logging
from
absl
import
logging
import
tensorflow
as
tf
import
tensorflow
as
tf
# Suffix appended to a checkpoint name to form its SavedModel export directory.
SAVED_MODULES_PATH_SUFFIX = 'saved_modules'


def make_saved_modules_directory_name(checkpoint_name: str) -> str:
  """Builds the SavedModel export directory name for a checkpoint.

  Args:
    checkpoint_name: Path or base name of the checkpoint.

  Returns:
    The checkpoint name with `_saved_modules` appended.
  """
  return checkpoint_name + '_' + SAVED_MODULES_PATH_SUFFIX
class
SavedModelCheckpointManager
(
tf
.
train
.
CheckpointManager
):
class
SavedModelCheckpointManager
(
tf
.
train
.
CheckpointManager
):
...
@@ -51,10 +53,10 @@ class SavedModelCheckpointManager(tf.train.CheckpointManager):
...
@@ -51,10 +53,10 @@ class SavedModelCheckpointManager(tf.train.CheckpointManager):
checkpoint_interval
=
checkpoint_interval
,
checkpoint_interval
=
checkpoint_interval
,
init_fn
=
init_fn
)
init_fn
=
init_fn
)
self
.
_modules_to_export
=
modules_to_export
self
.
_modules_to_export
=
modules_to_export
self
.
_savedmodels
=
self
.
_
get_existing_savedmodels
()
self
.
_savedmodels
=
self
.
get_existing_savedmodels
()
def
save
(
self
,
def
save
(
self
,
checkpoint_number
=
None
,
checkpoint_number
:
Optional
[
int
]
=
None
,
check_interval
:
bool
=
True
,
check_interval
:
bool
=
True
,
options
:
Optional
[
tf
.
train
.
CheckpointOptions
]
=
None
):
options
:
Optional
[
tf
.
train
.
CheckpointOptions
]
=
None
):
"""See base class."""
"""See base class."""
...
@@ -80,7 +82,7 @@ class SavedModelCheckpointManager(tf.train.CheckpointManager):
...
@@ -80,7 +82,7 @@ class SavedModelCheckpointManager(tf.train.CheckpointManager):
saved_modules_directories_to_keep
=
[
saved_modules_directories_to_keep
=
[
make_saved_modules_directory_name
(
ckpt
)
for
ckpt
in
self
.
checkpoints
make_saved_modules_directory_name
(
ckpt
)
for
ckpt
in
self
.
checkpoints
]
]
existing_saved_modules_dirs
=
self
.
_
get_existing_savedmodels
()
existing_saved_modules_dirs
=
self
.
get_existing_savedmodels
()
self
.
_savedmodels
=
[]
self
.
_savedmodels
=
[]
# Keep savedmodels in the same order as checkpoints (from oldest to newest).
# Keep savedmodels in the same order as checkpoints (from oldest to newest).
...
@@ -94,7 +96,7 @@ class SavedModelCheckpointManager(tf.train.CheckpointManager):
...
@@ -94,7 +96,7 @@ class SavedModelCheckpointManager(tf.train.CheckpointManager):
return
checkpoint_path
return
checkpoint_path
def
_
get_existing_savedmodels
(
self
):
def
get_existing_savedmodels
(
self
)
->
List
[
str
]
:
"""Gets a list of all existing SavedModel paths in `directory`.
"""Gets a list of all existing SavedModel paths in `directory`.
Returns:
Returns:
...
@@ -105,7 +107,7 @@ class SavedModelCheckpointManager(tf.train.CheckpointManager):
...
@@ -105,7 +107,7 @@ class SavedModelCheckpointManager(tf.train.CheckpointManager):
return
tf
.
io
.
gfile
.
glob
(
saved_modules_glob
)
return
tf
.
io
.
gfile
.
glob
(
saved_modules_glob
)
@
property
@
property
def
latest_savedmodel
(
self
):
def
latest_savedmodel
(
self
)
->
Union
[
str
,
None
]
:
"""The path of the most recent SavedModel in `directory`.
"""The path of the most recent SavedModel in `directory`.
Returns:
Returns:
...
@@ -116,10 +118,127 @@ class SavedModelCheckpointManager(tf.train.CheckpointManager):
...
@@ -116,10 +118,127 @@ class SavedModelCheckpointManager(tf.train.CheckpointManager):
return
None
return
None
@property
def savedmodels(self) -> List[str]:
  """A list of managed SavedModels.

  Returns:
    A list of SavedModel paths, sorted from oldest to newest.
  """
  return self._savedmodels
@property
def modules_to_export(self) -> Union[Mapping[str, tf.Module], None]:
  """The mapping of export names to `tf.Module`s, or `None` when not set.

  This is the `modules_to_export` mapping handed to the constructor; it is
  exported alongside each checkpoint when `save()` is called.
  """
  return self._modules_to_export
def get_savedmodel_number_from_path(self,
                                    savedmodel_path: str) -> Union[int, None]:
  """Extracts the savedmodel_number/checkpoint_number from a SavedModel path.

  The savedmodel_number is the global step when used with an orbit
  controller.

  Args:
    savedmodel_path: savedmodel directory path.

  Returns:
    Savedmodel number or None if no matched pattern found in savedmodel path.
  """
  # Paths produced by `make_saved_modules_directory_name` end with
  # `<digits>_saved_modules`; capture the digits directly.
  match = re.search(rf'(\d+)_{SAVED_MODULES_PATH_SUFFIX}$', savedmodel_path)
  if match is None:
    return None
  return int(match.group(1))
def savedmodels_iterator(self,
                         min_interval_secs: float = 0,
                         timeout: Optional[float] = None,
                         timeout_fn: Optional[Callable[[], bool]] = None):
  """Continuously yields new SavedModel files as they appear.

  New savedmodels are only looked for when control flow returns to the
  iterator; the logic mirrors `tf.train.checkpoints_iterator`.

  Args:
    min_interval_secs: The minimum number of seconds between yielding
      savedmodels.
    timeout: The maximum number of seconds to wait between savedmodels. If
      left as `None`, then the process will wait indefinitely.
    timeout_fn: Optional function to call after a timeout. If the function
      returns True, then it means that no new savedmodels will be generated
      and the iterator will exit. The function is called with no arguments.

  Yields:
    String paths to latest SavedModel files as they arrive.
  """
  current_path = None
  while True:
    next_path = self.wait_for_new_savedmodel(current_path, timeout=timeout)
    if next_path is None:
      # Wait timed out; without a timeout_fn that means we are done.
      if not timeout_fn:
        logging.info('Timed-out waiting for a savedmodel.')
        return
      if timeout_fn():
        # The timeout_fn indicated that we are truly done.
        return
      # The timeout_fn indicated that more savedmodels may come; keep polling.
      continue
    started_at = time.time()
    current_path = next_path
    yield current_path
    # Throttle so consecutive yields are at least min_interval_secs apart.
    remaining = started_at + min_interval_secs - time.time()
    if remaining > 0:
      time.sleep(remaining)
def wait_for_new_savedmodel(
    self,
    last_savedmodel: Optional[str] = None,
    seconds_to_sleep: float = 1.0,
    timeout: Optional[float] = None) -> Union[str, None]:
  """Waits until a new savedmodel file is found.

  Args:
    last_savedmodel: The last savedmodel path used or `None` if we're
      expecting a savedmodel for the first time.
    seconds_to_sleep: The number of seconds to sleep for before looking for a
      new savedmodel.
    timeout: The maximum number of seconds to wait. If left as `None`, then
      the process will wait indefinitely.

  Returns:
    A new savedmodel path, or None if the timeout was reached.
  """
  logging.info('Waiting for new savedmodel at %s', self._directory)
  stop_time = time.time() + timeout if timeout is not None else None

  last_savedmodel_number = 0
  if last_savedmodel:
    number = self.get_savedmodel_number_from_path(last_savedmodel)
    # Bug fix: `get_savedmodel_number_from_path` returns None when the path
    # does not end with `<digits>_saved_modules`; comparing `int > None`
    # below would raise TypeError in Python 3, so fall back to 0.
    if number is not None:
      last_savedmodel_number = number

  while True:
    if stop_time is not None and time.time() + seconds_to_sleep > stop_time:
      return None

    # Map each existing savedmodel's step number to its path.
    existing_savedmodels = {}
    for savedmodel_path in self.get_existing_savedmodels():
      savedmodel_number = self.get_savedmodel_number_from_path(
          savedmodel_path)
      if savedmodel_number is not None:
        existing_savedmodels[savedmodel_number] = savedmodel_path

    # Find the first savedmodel with a larger step number as next savedmodel.
    savedmodel_path = None
    for savedmodel_number in sorted(existing_savedmodels):
      if savedmodel_number > last_savedmodel_number:
        savedmodel_path = existing_savedmodels[savedmodel_number]
        break

    if savedmodel_path:
      logging.info('Found new savedmodel at %s', savedmodel_path)
      return savedmodel_path
    time.sleep(seconds_to_sleep)
official/core/savedmodel_checkpoint_manager_test.py
View file @
bf5a3910
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
# limitations under the License.
# limitations under the License.
import
os
import
os
import
time
from
typing
import
Iterable
from
typing
import
Iterable
import
tensorflow
as
tf
import
tensorflow
as
tf
...
@@ -32,12 +33,20 @@ def _models_exist(checkpoint_path: str, models: Iterable[str]) -> bool:
...
@@ -32,12 +33,20 @@ def _models_exist(checkpoint_path: str, models: Iterable[str]) -> bool:
class
CheckpointManagerTest
(
tf
.
test
.
TestCase
):
class
CheckpointManagerTest
(
tf
.
test
.
TestCase
):
def
testSimpleTest
(
self
):
def
_create_manager
(
self
,
max_to_keep
:
int
=
1
)
->
tf
.
train
.
CheckpointManager
:
"""Sets up SavedModelCheckpointManager object.
Args:
max_to_keep: max number of savedmodels to keep.
Returns:
created savedmodel manager.
"""
models
=
{
models
=
{
"
model_1
"
:
'
model_1
'
:
tf
.
keras
.
Sequential
(
tf
.
keras
.
Sequential
(
layers
=
[
tf
.
keras
.
layers
.
Dense
(
8
,
input_shape
=
(
16
,))]),
layers
=
[
tf
.
keras
.
layers
.
Dense
(
8
,
input_shape
=
(
16
,))]),
"
model_2
"
:
'
model_2
'
:
tf
.
keras
.
Sequential
(
tf
.
keras
.
Sequential
(
layers
=
[
tf
.
keras
.
layers
.
Dense
(
16
,
input_shape
=
(
32
,))]),
layers
=
[
tf
.
keras
.
layers
.
Dense
(
16
,
input_shape
=
(
32
,))]),
}
}
...
@@ -45,9 +54,13 @@ class CheckpointManagerTest(tf.test.TestCase):
...
@@ -45,9 +54,13 @@ class CheckpointManagerTest(tf.test.TestCase):
manager
=
savedmodel_checkpoint_manager
.
SavedModelCheckpointManager
(
manager
=
savedmodel_checkpoint_manager
.
SavedModelCheckpointManager
(
checkpoint
=
checkpoint
,
checkpoint
=
checkpoint
,
directory
=
self
.
get_temp_dir
(),
directory
=
self
.
get_temp_dir
(),
max_to_keep
=
1
,
max_to_keep
=
max_to_keep
,
modules_to_export
=
models
)
modules_to_export
=
models
)
return
manager
def
test_max_to_keep
(
self
):
manager
=
self
.
_create_manager
()
models
=
manager
.
modules_to_export
first_path
=
manager
.
save
()
first_path
=
manager
.
save
()
second_path
=
manager
.
save
()
second_path
=
manager
.
save
()
...
@@ -57,6 +70,45 @@ class CheckpointManagerTest(tf.test.TestCase):
...
@@ -57,6 +70,45 @@ class CheckpointManagerTest(tf.test.TestCase):
self
.
assertTrue
(
_models_exist
(
second_path
,
models
.
keys
()))
self
.
assertTrue
(
_models_exist
(
second_path
,
models
.
keys
()))
self
.
assertFalse
(
_models_exist
(
first_path
,
models
.
keys
()))
self
.
assertFalse
(
_models_exist
(
first_path
,
models
.
keys
()))
def test_returns_none_after_timeout(self):
  manager = self._create_manager()
  started = time.time()
  result = manager.wait_for_new_savedmodel(
      None, timeout=1.0, seconds_to_sleep=0.5)
  finished = time.time()
  self.assertIsNone(result)
  # At least one 0.5-second sleep cycle elapsed.
  self.assertGreater(finished, started + 0.5)
  # The 1.0-second timeout prevented a second full cycle.
  self.assertLess(finished, started + 0.6)
def test_saved_model_iterator(self):
  manager = self._create_manager(max_to_keep=2)
  for step in (1, 2, 3):
    self.assertIsNotNone(manager.save(checkpoint_number=step))

  # `savedmodels` is ordered from oldest to newest.
  expected_savedmodels = manager.savedmodels
  # `get_existing_savedmodels` makes no ordering guarantee.
  existing_savedmodels = manager.get_existing_savedmodels()
  savedmodels = list(manager.savedmodels_iterator(timeout=3.0))
  self.assertEqual(savedmodels, expected_savedmodels)
  self.assertEqual(set(savedmodels), set(existing_savedmodels))
def test_saved_model_iterator_timeout_fn(self):
  manager = self._create_manager()
  calls = [0]

  def _fake_timeout_fn():
    calls[0] += 1
    # Signal "truly done" only on the fourth invocation.
    return calls[0] > 3

  collected = list(
      manager.savedmodels_iterator(timeout=0.1, timeout_fn=_fake_timeout_fn))
  self.assertEqual([], collected)
  self.assertEqual(4, calls[0])
if __name__ == '__main__':
  tf.test.main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment