Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
e54fcee2
Commit
e54fcee2
authored
Jul 24, 2020
by
Ruoxin Sang
Committed by
A. Unique TensorFlower
Jul 24, 2020
Browse files
Internal change
PiperOrigin-RevId: 323098007
parent
a78b05b9
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
63 additions
and
48 deletions
+63
-48
orbit/controller.py
orbit/controller.py
+9
-18
orbit/controller_test.py
orbit/controller_test.py
+54
-30
No files found.
orbit/controller.py
View file @
e54fcee2
...
...
@@ -30,14 +30,6 @@ def _log_info(message: Text):
print
(
message
)
def
_validate_interval
(
interval
:
Optional
[
int
],
steps_per_loop
:
Optional
[
int
],
interval_name
:
str
):
if
interval
and
steps_per_loop
and
(
interval
%
steps_per_loop
!=
0
):
raise
ValueError
(
"The {} interval ({}) must be a multiple "
"of the steps_per_loop ({})"
.
format
(
interval_name
,
interval
,
steps_per_loop
))
class
Controller
:
"""Class that facilitates training and evaluation of models."""
...
...
@@ -103,8 +95,10 @@ class Controller:
if
summary_interval
is
not
None
:
if
summary_interval
<=
0
:
raise
ValueError
(
"`summary_interval` should be larger than 0"
)
_validate_interval
(
summary_interval
,
steps_per_loop
,
interval_name
=
"summary"
)
if
summary_interval
%
steps_per_loop
!=
0
:
raise
ValueError
(
"The summary interval ({}) must be a multiple "
"of the steps_per_loop ({})"
.
format
(
summary_interval
,
steps_per_loop
))
self
.
trainer
=
trainer
self
.
evaluator
=
evaluator
...
...
@@ -142,9 +136,6 @@ class Controller:
# TODO(momernick): We probably only want to do this on certain occasions?
if
self
.
checkpoint_manager
is
not
None
:
checkpoint_interval
=
self
.
checkpoint_manager
.
checkpoint_interval
_validate_interval
(
checkpoint_interval
,
steps_per_loop
,
interval_name
=
"checkpoint"
)
model_restored
=
self
.
restore_checkpoint
()
if
not
model_restored
and
(
checkpoint_interval
and
self
.
trainer
is
not
None
):
...
...
@@ -271,15 +262,15 @@ class Controller:
train_steps: The global step count to train up to.
eval_steps: The number of steps to run during an evaluation. If None,
this method will evaluate over the entire evaluation dataset.
eval_interval: The number of training steps to run between evaluations.
Must be a multiple of the controller's `steps_per_loop` init arg. If
None, evaluation will only be performed after training is complete.
eval_interval: The number of training steps to run between evaluations.
If set, training will always stop every `eval_interval` steps, even if
this results in a shorter inner loop than specified by `steps_per_loop`
setting. If None, evaluation will only be performed after training is
complete.
Raises:
ValueError: If eval_interval is not a multiple of self.steps_per_loop.
"""
_validate_interval
(
eval_interval
,
self
.
steps_per_loop
,
interval_name
=
"eval"
)
current_step
=
self
.
global_step
.
numpy
()
# This is an expensive access.
eval_interval
=
eval_interval
or
(
train_steps
-
current_step
)
while
current_step
<
train_steps
:
...
...
orbit/controller_test.py
View file @
e54fcee2
...
...
@@ -33,19 +33,15 @@ def create_model():
def summaries_with_matching_keyword(keyword, summary_dir):
  """Returns summary protos matching given keyword from event file."""
  matches = []
  # Look at the most recent events file in the directory.
  event_paths = tf.io.gfile.glob(os.path.join(summary_dir, "events*"))
  for event in tf.compat.v1.train.summary_iterator(event_paths[-1]):
    # NOTE(review): proto message fields are typically never None; this guard
    # mirrors the original code — confirm whether it is intentional.
    if event.summary is not None:
      for value in event.summary.value:
        if keyword in value.tag:
          logging.info(event)
          matches.append(event.summary)
  return matches
def
dataset_fn
(
ctx
):
...
...
@@ -219,13 +215,13 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
# Loss and accuracy values should be written into summaries.
self
.
assertNotEmpty
(
tf
.
io
.
gfile
.
listdir
(
os
.
path
.
join
(
self
.
model_dir
,
"summaries/train"
)))
self
.
assert
True
(
check_eventfile_for
_keyword
(
self
.
assert
NotEmpty
(
summaries_with_matching
_keyword
(
"loss"
,
os
.
path
.
join
(
self
.
model_dir
,
"summaries/train"
)))
self
.
assertNotEmpty
(
tf
.
io
.
gfile
.
listdir
(
os
.
path
.
join
(
self
.
model_dir
,
"summaries/eval"
)))
self
.
assert
True
(
check_eventfile_for
_keyword
(
self
.
assert
NotEmpty
(
summaries_with_matching
_keyword
(
"eval_loss"
,
os
.
path
.
join
(
self
.
model_dir
,
"summaries/eval"
)))
# No checkpoint, so global step starts from 0.
test_runner
.
global_step
.
assign
(
0
)
...
...
@@ -275,13 +271,13 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
# Loss and accuracy values should be written into summaries.
self
.
assertNotEmpty
(
tf
.
io
.
gfile
.
listdir
(
os
.
path
.
join
(
self
.
model_dir
,
"summaries/train"
)))
self
.
assert
True
(
check_eventfile_for
_keyword
(
self
.
assert
NotEmpty
(
summaries_with_matching
_keyword
(
"loss"
,
os
.
path
.
join
(
self
.
model_dir
,
"summaries/train"
)))
self
.
assertNotEmpty
(
tf
.
io
.
gfile
.
listdir
(
os
.
path
.
join
(
self
.
model_dir
,
"summaries/eval"
)))
self
.
assert
True
(
check_eventfile_for
_keyword
(
self
.
assert
NotEmpty
(
summaries_with_matching
_keyword
(
"eval_loss"
,
os
.
path
.
join
(
self
.
model_dir
,
"summaries/eval"
)))
def
test_train_only
(
self
):
...
...
@@ -311,8 +307,8 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
# Only train summaries are written.
self
.
assertNotEmpty
(
tf
.
io
.
gfile
.
listdir
(
os
.
path
.
join
(
self
.
model_dir
,
"summaries/train"
)))
self
.
assert
True
(
check_eventfile_for
_keyword
(
self
.
assert
NotEmpty
(
summaries_with_matching
_keyword
(
"loss"
,
os
.
path
.
join
(
self
.
model_dir
,
"summaries/train"
)))
self
.
assertFalse
(
tf
.
io
.
gfile
.
exists
(
os
.
path
.
join
(
self
.
model_dir
,
"summaries/eval"
)))
...
...
@@ -340,8 +336,8 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
tf
.
io
.
gfile
.
exists
(
os
.
path
.
join
(
self
.
model_dir
,
"summaries/train"
)))
self
.
assertNotEmpty
(
tf
.
io
.
gfile
.
listdir
(
os
.
path
.
join
(
self
.
model_dir
,
"summaries/eval"
)))
self
.
assert
True
(
check_eventfile_for
_keyword
(
self
.
assert
NotEmpty
(
summaries_with_matching
_keyword
(
"eval_loss"
,
os
.
path
.
join
(
self
.
model_dir
,
"summaries/eval"
)))
# Tests continuous eval with timeout and timeout_fn.
...
...
@@ -423,8 +419,8 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
# Only train summaries are written.
self
.
assertNotEmpty
(
tf
.
io
.
gfile
.
listdir
(
os
.
path
.
join
(
self
.
model_dir
,
"summaries/train"
)))
self
.
assert
True
(
check_eventfile_for
_keyword
(
self
.
assert
NotEmpty
(
summaries_with_matching
_keyword
(
"loss"
,
os
.
path
.
join
(
self
.
model_dir
,
"summaries/train"
)))
self
.
assertFalse
(
tf
.
io
.
gfile
.
exists
(
os
.
path
.
join
(
self
.
model_dir
,
"summaries/eval"
)))
...
...
@@ -453,12 +449,12 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
# Loss and accuracy values should be written into summaries.
self
.
assertNotEmpty
(
tf
.
io
.
gfile
.
listdir
(
os
.
path
.
join
(
self
.
model_dir
,
"summaries"
)))
self
.
assert
True
(
check_eventfile_for
_keyword
(
"loss"
,
os
.
path
.
join
(
self
.
model_dir
,
"summaries"
)))
self
.
assert
True
(
check_eventfile_for_keyword
(
"eval_loss"
,
os
.
path
.
join
(
self
.
model_dir
,
"summaries"
)))
self
.
assert
NotEmpty
(
summaries_with_matching
_keyword
(
"loss"
,
os
.
path
.
join
(
self
.
model_dir
,
"summaries"
)))
self
.
assert
NotEmpty
(
summaries_with_matching_keyword
(
"eval_loss"
,
os
.
path
.
join
(
self
.
model_dir
,
"summaries"
)))
def
test_early_stop_on_eval_loss
(
self
):
test_runner
=
TestRunner
()
...
...
@@ -518,8 +514,8 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
# Only eval summaries are written
self
.
assertNotEmpty
(
tf
.
io
.
gfile
.
listdir
(
os
.
path
.
join
(
self
.
model_dir
,
"summaries/eval"
)))
self
.
assert
True
(
check_eventfile_for
_keyword
(
self
.
assert
NotEmpty
(
summaries_with_matching
_keyword
(
"eval_loss"
,
os
.
path
.
join
(
self
.
model_dir
,
"summaries/eval"
)))
def
test_train_and_evaluate_reset_datasets
(
self
):
...
...
@@ -546,5 +542,33 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
test_controller
.
train_and_evaluate
(
train_steps
=
10
,
eval_steps
=
2
,
eval_interval
=
6
)
def test_eval_and_checkpoint_interval(self):
  """Verifies checkpointing and evaluation honor their intervals."""
  test_runner = TestRunner()
  checkpoint = tf.train.Checkpoint(
      model=test_runner.model, optimizer=test_runner.optimizer)
  checkpoint_manager = tf.train.CheckpointManager(
      checkpoint,
      self.model_dir,
      max_to_keep=None,
      step_counter=test_runner.global_step,
      checkpoint_interval=5)
  test_controller = controller.Controller(
      trainer=test_runner,
      evaluator=test_runner,
      global_step=test_runner.global_step,
      steps_per_loop=10,
      checkpoint_manager=checkpoint_manager)
  test_controller.train_and_evaluate(
      train_steps=10, eval_steps=2, eval_interval=5)

  # Expect 3 checkpoints to be saved, at steps 0, 5 and 10.
  self.assertLen(
      tf.io.gfile.glob(os.path.join(self.model_dir, "ckpt-*.data*")), 3)
  # Expect evaluation to be performed 2 times, at steps 5 and 10.
  self.assertLen(
      summaries_with_matching_keyword("eval_loss", self.model_dir), 2)
if __name__ == "__main__":
  # Run all TestCase methods in this module under the TF test runner.
  tf.test.main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment