Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
13d44a05
Commit
13d44a05
authored
Mar 22, 2020
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Mar 22, 2020
Browse files
Fix controller bugs. Add tests for optional args.
PiperOrigin-RevId: 302323163
parent
a348a90b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
60 additions
and
7 deletions
+60
-7
official/staging/training/controller.py
official/staging/training/controller.py
+14
-7
official/staging/training/controller_test.py
official/staging/training/controller_test.py
+46
-0
No files found.
official/staging/training/controller.py
View file @
13d44a05
...
...
@@ -117,11 +117,18 @@ class Controller(object):
if
self
.
train_fn
is
not
None
:
self
.
train_steps
=
train_steps
self
.
steps_per_loop
=
steps_per_loop
self
.
summary_dir
=
summary_dir
or
checkpoint_manager
.
directory
if
summary_dir
:
self
.
summary_dir
=
summary_dir
elif
checkpoint_manager
:
self
.
summary_dir
=
checkpoint_manager
.
directory
else
:
self
.
summary_dir
=
None
self
.
summary_interval
=
summary_interval
summary_writer
=
tf
.
summary
.
create_file_writer
(
self
.
summary_dir
)
if
self
.
summary_interval
else
None
if
self
.
summary_dir
and
self
.
summary_interval
:
summary_writer
=
tf
.
summary
.
create_file_writer
(
self
.
summary_dir
)
else
:
summary_writer
=
None
# TODO(rxsang): Consider pass SummaryManager directly into Controller for
# maximum customizability.
self
.
summary_manager
=
utils
.
SummaryManager
(
...
...
@@ -140,14 +147,14 @@ class Controller(object):
self
.
eval_steps
=
eval_steps
self
.
eval_interval
=
eval_interval
# Create and initialize the interval triggers.
# Create
s
and initialize
s
the interval triggers.
self
.
eval_trigger
=
utils
.
IntervalTrigger
(
self
.
eval_interval
,
self
.
global_step
.
numpy
())
self
.
global_step
.
numpy
())
# pytype: disable=attribute-error
if
self
.
global_step
:
tf
.
summary
.
experimental
.
set_step
(
self
.
global_step
)
# Restore
M
odel if needed.
# Restore
s the m
odel if needed.
if
self
.
checkpoint_manager
is
not
None
:
model_restored
=
self
.
_restore_model
()
if
not
model_restored
and
self
.
checkpoint_manager
.
checkpoint_interval
:
...
...
@@ -192,7 +199,7 @@ class Controller(object):
self
.
eval_summary_manager
.
flush
()
def
_maybe_save_checkpoints
(
self
,
current_step
,
force_trigger
=
False
):
if
self
.
checkpoint_manager
.
checkpoint_interval
:
if
self
.
checkpoint_manager
and
self
.
checkpoint_manager
.
checkpoint_interval
:
ckpt_path
=
self
.
checkpoint_manager
.
save
(
checkpoint_number
=
current_step
,
check_interval
=
not
force_trigger
)
if
ckpt_path
is
not
None
:
...
...
official/staging/training/controller_test.py
View file @
13d44a05
...
...
@@ -143,6 +143,52 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
super
(
ControllerTest
,
self
).
setUp
()
self
.
model_dir
=
self
.
get_temp_dir
()
def
test_no_checkpoint
(
self
):
test_runnable
=
TestRunnable
()
# No checkpoint manager and no strategy.
test_controller
=
controller
.
Controller
(
train_fn
=
test_runnable
.
train
,
eval_fn
=
test_runnable
.
evaluate
,
global_step
=
test_runnable
.
global_step
,
train_steps
=
10
,
steps_per_loop
=
2
,
summary_dir
=
os
.
path
.
join
(
self
.
model_dir
,
"summaries/train"
),
summary_interval
=
2
,
eval_summary_dir
=
os
.
path
.
join
(
self
.
model_dir
,
"summaries/eval"
),
eval_steps
=
2
,
eval_interval
=
5
)
test_controller
.
train
(
evaluate
=
True
)
self
.
assertEqual
(
test_runnable
.
global_step
.
numpy
(),
10
)
# Loss and accuracy values should be written into summaries.
self
.
assertNotEmpty
(
tf
.
io
.
gfile
.
listdir
(
os
.
path
.
join
(
self
.
model_dir
,
"summaries/train"
)))
self
.
assertTrue
(
check_eventfile_for_keyword
(
"loss"
,
os
.
path
.
join
(
self
.
model_dir
,
"summaries/train"
)))
self
.
assertNotEmpty
(
tf
.
io
.
gfile
.
listdir
(
os
.
path
.
join
(
self
.
model_dir
,
"summaries/eval"
)))
self
.
assertTrue
(
check_eventfile_for_keyword
(
"eval_loss"
,
os
.
path
.
join
(
self
.
model_dir
,
"summaries/eval"
)))
# No checkpoint, so global step starts from 0.
test_runnable
.
global_step
.
assign
(
0
)
test_controller
.
train
(
evaluate
=
True
)
self
.
assertEqual
(
test_runnable
.
global_step
.
numpy
(),
10
)
def
test_no_checkpoint_and_summaries
(
self
):
test_runnable
=
TestRunnable
()
# No checkpoint + summary directories.
test_controller
=
controller
.
Controller
(
train_fn
=
test_runnable
.
train
,
eval_fn
=
test_runnable
.
evaluate
,
global_step
=
test_runnable
.
global_step
,
train_steps
=
10
,
steps_per_loop
=
2
,
eval_steps
=
2
,
eval_interval
=
5
)
test_controller
.
train
(
evaluate
=
True
)
self
.
assertEqual
(
test_runnable
.
global_step
.
numpy
(),
10
)
@
combinations
.
generate
(
all_strategy_combinations
())
def
test_train_and_evaluate
(
self
,
strategy
):
with
strategy
.
scope
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment