Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
85a48cba
Commit
85a48cba
authored
Jun 16, 2021
by
Dan Holtmann-Rice
Committed by
A. Unique TensorFlower
Jun 16, 2021
Browse files
Internal change
PiperOrigin-RevId: 379875442
parent
cf5732d4
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
35 additions
and
16 deletions
+35
-16
orbit/actions/export_saved_model.py
orbit/actions/export_saved_model.py
+18
-16
orbit/actions/export_saved_model_test.py
orbit/actions/export_saved_model_test.py
+17
-0
No files found.
orbit/actions/export_saved_model.py
View file @
85a48cba
...
@@ -14,24 +14,32 @@
...
@@ -14,24 +14,32 @@
"""Provides the `ExportSavedModel` action and associated helper classes."""
"""Provides the `ExportSavedModel` action and associated helper classes."""
import
re
from
typing
import
Callable
,
Optional
from
typing
import
Callable
,
Optional
import
tensorflow
as
tf
import
tensorflow
as
tf
def
_id_key
(
filename
):
_
,
id_num
=
filename
.
rsplit
(
'-'
,
maxsplit
=
1
)
return
int
(
id_num
)
def
_find_managed_files
(
base_name
):
r
"""Returns all files matching '{base_name}-\d+', in sorted order."""
managed_file_regex
=
re
.
compile
(
rf
'
{
re
.
escape
(
base_name
)
}
-\d+$'
)
filenames
=
tf
.
io
.
gfile
.
glob
(
f
'
{
base_name
}
-*'
)
filenames
=
filter
(
managed_file_regex
.
match
,
filenames
)
return
sorted
(
filenames
,
key
=
_id_key
)
class
_CounterIdFn
:
class
_CounterIdFn
:
"""Implements a counter-based ID function for `ExportFileManager`."""
"""Implements a counter-based ID function for `ExportFileManager`."""
def
__init__
(
self
,
base_name
:
str
):
def
__init__
(
self
,
base_name
:
str
):
filenames
=
tf
.
io
.
gfile
.
glob
(
f
'
{
base_name
}
-*'
)
managed_files
=
_find_managed_files
(
base_name
)
max_counter
=
-
1
self
.
value
=
_id_key
(
managed_files
[
-
1
])
+
1
if
managed_files
else
0
for
filename
in
filenames
:
try
:
_
,
file_number
=
filename
.
rsplit
(
'-'
,
maxsplit
=
1
)
max_counter
=
max
(
max_counter
,
int
(
file_number
))
except
ValueError
:
continue
self
.
value
=
max_counter
+
1
def
__call__
(
self
):
def
__call__
(
self
):
output
=
self
.
value
output
=
self
.
value
...
@@ -82,13 +90,7 @@ class ExportFileManager:
...
@@ -82,13 +90,7 @@ class ExportFileManager:
`ExportFileManager` instance, sorted in increasing integer order of the
`ExportFileManager` instance, sorted in increasing integer order of the
IDs returned by `next_id_fn`.
IDs returned by `next_id_fn`.
"""
"""
return
_find_managed_files
(
self
.
_base_name
)
def
id_key
(
name
):
_
,
id_num
=
name
.
rsplit
(
'-'
,
maxsplit
=
1
)
return
int
(
id_num
)
filenames
=
tf
.
io
.
gfile
.
glob
(
f
'
{
self
.
_base_name
}
-*'
)
return
sorted
(
filenames
,
key
=
id_key
)
def
clean_up
(
self
):
def
clean_up
(
self
):
"""Cleans up old files matching `{base_name}-*`.
"""Cleans up old files matching `{base_name}-*`.
...
...
orbit/actions/export_saved_model_test.py
View file @
85a48cba
...
@@ -105,6 +105,23 @@ class ExportSavedModelTest(tf.test.TestCase):
...
@@ -105,6 +105,23 @@ class ExportSavedModelTest(tf.test.TestCase):
_id_sorted_file_base_names
(
directory
.
full_path
),
_id_sorted_file_base_names
(
directory
.
full_path
),
[
'basename-200'
,
'basename-1000'
])
[
'basename-200'
,
'basename-1000'
])
def
test_export_file_manager_managed_files
(
self
):
directory
=
self
.
create_tempdir
()
directory
.
create_file
(
'basename-5'
)
directory
.
create_file
(
'basename-10'
)
directory
.
create_file
(
'basename-50'
)
directory
.
create_file
(
'basename-1000'
)
directory
.
create_file
(
'basename-9'
)
directory
.
create_file
(
'basename-10-suffix'
)
base_name
=
os
.
path
.
join
(
directory
.
full_path
,
'basename'
)
manager
=
actions
.
ExportFileManager
(
base_name
,
max_to_keep
=
3
)
self
.
assertLen
(
manager
.
managed_files
,
5
)
self
.
assertEqual
(
manager
.
next_name
(),
f
'
{
base_name
}
-1001'
)
manager
.
clean_up
()
self
.
assertEqual
(
manager
.
managed_files
,
[
f
'
{
base_name
}
-10'
,
f
'
{
base_name
}
-50'
,
f
'
{
base_name
}
-1000'
])
def
test_export_saved_model
(
self
):
def
test_export_saved_model
(
self
):
directory
=
self
.
create_tempdir
()
directory
=
self
.
create_tempdir
()
base_name
=
os
.
path
.
join
(
directory
.
full_path
,
'basename'
)
base_name
=
os
.
path
.
join
(
directory
.
full_path
,
'basename'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment