Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
d41ed934
Unverified
Commit
d41ed934
authored
May 30, 2018
by
Yanhui Liang
Committed by
GitHub
May 30, 2018
Browse files
Fix hooks_test for examples/second hook (#4411)
* Fix hooks_test * Add more comments * Fix lints
parent
04c81871
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
66 additions
and
55 deletions
+66
-55
official/utils/logs/hooks_test.py
official/utils/logs/hooks_test.py
+66
-55
No files found.
official/utils/logs/hooks_test.py
View file @
d41ed934
...
...
@@ -22,17 +22,25 @@ from __future__ import print_function
import
time
import
tensorflow
as
tf
# pylint: disable=g-bad-import-order
from
tensorflow.python.training
import
monitored_session
# pylint: disable=g-bad-import-order
from
official.utils.logs
import
hooks
from
official.utils.testing
import
mock_lib
tf
.
logging
.
set_verbosity
(
tf
.
logging
.
DEBUG
)
class
ExamplesPerSecondHookTest
(
tf
.
test
.
TestCase
):
"""Tests for the ExamplesPerSecondHook."""
"""Tests for the ExamplesPerSecondHook.
In this test, we explicitly run global_step tensor after train_op in order to
grab the correct global step value. This is to correct for discrepancies in
reported global step when running on GPUs. As in the after_run functions in
ExamplesPerSecondHook, the global step from run_results
(global_step = run_values.results) is not always correct and taken as the
stale global_step (which may be 1 off the correct value). The exact
global_step value should be from run_context
(global_step = run_context.session.run(global_step_tensor)
"""
def
setUp
(
self
):
"""Mock out logging calls to verify if correct info is being monitored."""
...
...
@@ -40,8 +48,9 @@ class ExamplesPerSecondHookTest(tf.test.TestCase):
self
.
graph
=
tf
.
Graph
()
with
self
.
graph
.
as_default
():
self
.
global_step
=
tf
.
train
.
get_or_create_global_step
()
self
.
train_op
=
tf
.
assign_add
(
self
.
global_step
,
1
)
tf
.
train
.
create_global_step
()
self
.
train_op
=
tf
.
assign_add
(
tf
.
train
.
get_global_step
(),
1
)
self
.
global_step
=
tf
.
train
.
get_global_step
()
def
test_raise_in_both_secs_and_steps
(
self
):
with
self
.
assertRaises
(
ValueError
):
...
...
@@ -59,86 +68,88 @@ class ExamplesPerSecondHookTest(tf.test.TestCase):
every_n_secs
=
None
,
metric_logger
=
self
.
_logger
)
def
_validate_log_every_n_steps
(
self
,
sess
,
every_n_steps
,
warm_steps
):
def
_validate_log_every_n_steps
(
self
,
every_n_steps
,
warm_steps
):
hook
=
hooks
.
ExamplesPerSecondHook
(
batch_size
=
256
,
every_n_steps
=
every_n_steps
,
warm_steps
=
warm_steps
,
metric_logger
=
self
.
_logger
)
hook
.
begin
()
mon_sess
=
monitored_session
.
_HookedSession
(
sess
,
[
hook
])
# pylint: disable=protected-access
sess
.
run
(
tf
.
global_variables_initializer
())
for
_
in
range
(
every_n_steps
):
with
tf
.
train
.
MonitoredSession
(
tf
.
train
.
ChiefSessionCreator
(),
[
hook
])
as
mon_sess
:
for
_
in
range
(
every_n_steps
):
# Explicitly run global_step after train_op to get the accurate
# global_step value
mon_sess
.
run
(
self
.
train_op
)
mon_sess
.
run
(
self
.
global_step
)
# Nothing should be in the list yet
self
.
assertFalse
(
self
.
_logger
.
logged_metric
)
mon_sess
.
run
(
self
.
train_op
)
# Nothing should be in the list yet
self
.
assertFalse
(
self
.
_logger
.
logged_metric
)
global_step_val
=
mon_sess
.
run
(
self
.
global_step
)
mon_sess
.
run
(
self
.
train_op
)
global_step_val
=
sess
.
run
(
self
.
global_step
)
if
global_step_val
>
warm_steps
:
self
.
_assert_metrics
()
else
:
# Nothing should be in the list yet
self
.
assertFalse
(
self
.
_logger
.
logged_metric
)
if
global_step_val
>
warm_steps
:
self
.
_assert_metrics
()
else
:
# Nothing should be in the list yet
self
.
assertFalse
(
self
.
_logger
.
logged_metric
)
# Add additional run to verify proper reset when called multiple times.
prev_log_len
=
len
(
self
.
_logger
.
logged_metric
)
mon_sess
.
run
(
self
.
train_op
)
global_step_val
=
sess
.
run
(
self
.
global_step
)
if
every_n_steps
==
1
and
global_step_val
>
warm_steps
:
# Each time, we log two additional metrics. Did exactly 2 get added?
self
.
assertEqual
(
len
(
self
.
_logger
.
logged_metric
),
prev_log_len
+
2
)
else
:
# No change in the size of the metric list.
self
.
assertEqual
(
len
(
self
.
_logger
.
logged_metric
),
prev_log_len
)
# Add additional run to verify proper reset when called multiple times.
prev_log_len
=
len
(
self
.
_logger
.
logged_metric
)
mon_sess
.
run
(
self
.
train_op
)
global_step_val
=
mon_sess
.
run
(
self
.
global_step
)
hook
.
end
(
sess
)
if
every_n_steps
==
1
and
global_step_val
>
warm_steps
:
# Each time, we log two additional metrics. Did exactly 2 get added?
self
.
assertEqual
(
len
(
self
.
_logger
.
logged_metric
),
prev_log_len
+
2
)
else
:
# No change in the size of the metric list.
self
.
assertEqual
(
len
(
self
.
_logger
.
logged_metric
),
prev_log_len
)
def
test_examples_per_sec_every_1_steps
(
self
):
with
self
.
graph
.
as_default
()
,
tf
.
Session
()
as
sess
:
self
.
_validate_log_every_n_steps
(
sess
,
1
,
0
)
with
self
.
graph
.
as_default
():
self
.
_validate_log_every_n_steps
(
1
,
0
)
def
test_examples_per_sec_every_5_steps
(
self
):
with
self
.
graph
.
as_default
()
,
tf
.
Session
()
as
sess
:
self
.
_validate_log_every_n_steps
(
sess
,
5
,
0
)
with
self
.
graph
.
as_default
():
self
.
_validate_log_every_n_steps
(
5
,
0
)
def
test_examples_per_sec_every_1_steps_with_warm_steps
(
self
):
with
self
.
graph
.
as_default
()
,
tf
.
Session
()
as
sess
:
self
.
_validate_log_every_n_steps
(
sess
,
1
,
10
)
with
self
.
graph
.
as_default
():
self
.
_validate_log_every_n_steps
(
1
,
10
)
def
test_examples_per_sec_every_5_steps_with_warm_steps
(
self
):
with
self
.
graph
.
as_default
()
,
tf
.
Session
()
as
sess
:
self
.
_validate_log_every_n_steps
(
sess
,
5
,
10
)
with
self
.
graph
.
as_default
():
self
.
_validate_log_every_n_steps
(
5
,
10
)
def
_validate_log_every_n_secs
(
self
,
sess
,
every_n_secs
):
def
_validate_log_every_n_secs
(
self
,
every_n_secs
):
hook
=
hooks
.
ExamplesPerSecondHook
(
batch_size
=
256
,
every_n_steps
=
None
,
every_n_secs
=
every_n_secs
,
metric_logger
=
self
.
_logger
)
hook
.
begin
()
mon_sess
=
monitored_session
.
_HookedSession
(
sess
,
[
hook
])
# pylint: disable=protected-access
sess
.
run
(
tf
.
global_variables_initializer
())
mon_sess
.
run
(
self
.
train_op
)
# Nothing should be in the list yet
self
.
assertFalse
(
self
.
_logger
.
logged_metric
)
time
.
sleep
(
every_n_secs
)
mon_sess
.
run
(
self
.
train_op
)
self
.
_assert_metrics
()
with
tf
.
train
.
MonitoredSession
(
tf
.
train
.
ChiefSessionCreator
(),
[
hook
])
as
mon_sess
:
# Explicitly run global_step after train_op to get the accurate
# global_step value
mon_sess
.
run
(
self
.
train_op
)
mon_sess
.
run
(
self
.
global_step
)
# Nothing should be in the list yet
self
.
assertFalse
(
self
.
_logger
.
logged_metric
)
time
.
sleep
(
every_n_secs
)
hook
.
end
(
sess
)
mon_sess
.
run
(
self
.
train_op
)
mon_sess
.
run
(
self
.
global_step
)
self
.
_assert_metrics
()
def
test_examples_per_sec_every_1_secs
(
self
):
with
self
.
graph
.
as_default
()
,
tf
.
Session
()
as
sess
:
self
.
_validate_log_every_n_secs
(
sess
,
1
)
with
self
.
graph
.
as_default
():
self
.
_validate_log_every_n_secs
(
1
)
def
test_examples_per_sec_every_5_secs
(
self
):
with
self
.
graph
.
as_default
()
,
tf
.
Session
()
as
sess
:
self
.
_validate_log_every_n_secs
(
sess
,
5
)
with
self
.
graph
.
as_default
():
self
.
_validate_log_every_n_secs
(
5
)
def
_assert_metrics
(
self
):
metrics
=
self
.
_logger
.
logged_metric
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment