ModelZoo / ResNet50_tensorflow / Commits / b2c9e3f5

Commit b2c9e3f5 was authored on Feb 08, 2019 by Goldie Gadde and committed by Toby Boyd on Feb 08, 2019.
Revert "Revert "tf_upgrade_v2 on resnet and utils folders. (#6154)" (#6162)" (#6167)
This reverts commit
57e07520
.
parent
57e07520
Changes: 34 in total. This page (1 of 2) shows 14 changed files with 121 additions and 114 deletions (+121 -114).
official/utils/flags/_device.py                 +1  -1
official/utils/logs/hooks.py                    +4  -4
official/utils/logs/hooks_helper.py             +3  -3
official/utils/logs/hooks_helper_test.py        +1  -1
official/utils/logs/hooks_test.py               +9  -8
official/utils/logs/logger.py                   +24 -20
official/utils/logs/logger_test.py              +22 -20
official/utils/logs/metric_hook.py              +2  -2
official/utils/logs/metric_hook_test.py         +17 -17
official/utils/logs/mlperf_helper.py            +2  -2
official/utils/misc/model_helpers.py            +4  -4
official/utils/misc/model_helpers_test.py       +9  -9
official/utils/testing/reference_data.py        +15 -15
official/utils/testing/reference_data_test.py   +8  -8
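All of the hunks below were produced mechanically by the tf_upgrade_v2 conversion script named in the commit title; they move TensorFlow 1.x symbols to their 2.x-compatible homes without changing behaviour. As an illustrative digest only (this table is not part of the commit), the renames that recur on this page can be summarized as:

```python
# Illustrative digest of the 1.x -> 2.x renames applied in the hunks below.
V1_TO_V2 = {
    "tf.logging":                  "tf.compat.v1.logging",
    "tf.train.SessionRunHook":     "tf.estimator.SessionRunHook",
    "tf.train.SessionRunArgs":     "tf.estimator.SessionRunArgs",
    "tf.train.SecondOrStepTimer":  "tf.estimator.SecondOrStepTimer",
    "tf.train.LoggingTensorHook":  "tf.estimator.LoggingTensorHook",
    "tf.train.ProfilerHook":       "tf.estimator.ProfilerHook",
    "tf.train.get_global_step":    "tf.compat.v1.train.get_global_step",
    "tf.gfile.Exists":             "tf.io.gfile.exists",
    "tf.gfile.IsDirectory":        "tf.io.gfile.isdir",
    "tf.gfile.MakeDirs":           "tf.io.gfile.makedirs",
    "tf.gfile.DeleteRecursively":  "tf.io.gfile.rmtree",
    "tf.gfile.Open":               "tf.io.gfile.GFile",
    "tf.VERSION":                  "tf.version.VERSION",
    "tf.GIT_VERSION":              "tf.version.GIT_VERSION",
    "tf.Session":                  "tf.compat.v1.Session",
    "tf.get_variable":             "tf.compat.v1.get_variable",
    "tf.set_random_seed":          "tf.compat.v1.set_random_seed",
    "tf.random_uniform":           "tf.random.uniform",
    "tf.layers.dense":             "tf.compat.v1.layers.dense",
}
```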
official/utils/flags/_device.py (view file @ b2c9e3f5)

@@ -39,7 +39,7 @@ def require_cloud_storage(flag_names):
   valid_flags = True
   for key in flag_names:
     if not flag_values[key].startswith("gs://"):
-      tf.logging.error("{} must be a GCS path.".format(key))
+      tf.compat.v1.logging.error("{} must be a GCS path.".format(key))
       valid_flags = False
   return valid_flags
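The change above is typical of the whole commit: tf.logging moves under tf.compat.v1. A minimal sketch of the post-upgrade call, assuming TensorFlow 1.13 or newer (where tf.compat.v1 exists); the flag key is a placeholder:

```python
import tensorflow as tf

tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

# tf.logging.* is gone from the 2.x namespace; the compat.v1 alias behaves
# identically under 1.13+.  "data_dir" here is just a stand-in flag name.
tf.compat.v1.logging.error("{} must be a GCS path.".format("data_dir"))
```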
official/utils/logs/hooks.py (view file @ b2c9e3f5)

@@ -25,7 +25,7 @@ import tensorflow as tf  # pylint: disable=g-bad-import-order
 from official.utils.logs import logger
-class ExamplesPerSecondHook(tf.train.SessionRunHook):
+class ExamplesPerSecondHook(tf.estimator.SessionRunHook):
   """Hook to print out examples per second.
   Total time is tracked and then divided by the total number of steps
@@ -66,7 +66,7 @@ class ExamplesPerSecondHook(tf.train.SessionRunHook):
     self._logger = metric_logger or logger.BaseBenchmarkLogger()
-    self._timer = tf.train.SecondOrStepTimer(
+    self._timer = tf.estimator.SecondOrStepTimer(
         every_steps=every_n_steps, every_secs=every_n_secs)
     self._step_train_time = 0
@@ -76,7 +76,7 @@ class ExamplesPerSecondHook(tf.train.SessionRunHook):
   def begin(self):
     """Called once before using the session to check global step."""
-    self._global_step_tensor = tf.train.get_global_step()
+    self._global_step_tensor = tf.compat.v1.train.get_global_step()
     if self._global_step_tensor is None:
       raise RuntimeError(
           "Global step should be created to use StepCounterHook.")
@@ -90,7 +90,7 @@ class ExamplesPerSecondHook(tf.train.SessionRunHook):
     Returns:
       A SessionRunArgs object or None if never triggered.
     """
-    return tf.train.SessionRunArgs(self._global_step_tensor)
+    return tf.estimator.SessionRunArgs(self._global_step_tensor)
   def after_run(self, run_context, run_values):  # pylint: disable=unused-argument
     """Called after each call to run().
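The hook keeps the usual begin / before_run / after_run structure; only the namespaces move. A stripped-down sketch of that structure under the post-upgrade names (TF 1.x graph mode assumed; the class name and its logging behaviour are illustrative, not the repo's ExamplesPerSecondHook):

```python
import tensorflow as tf


class StepLoggerHook(tf.estimator.SessionRunHook):
  """Hypothetical hook: logs the global step every N steps."""

  def __init__(self, every_n_steps=100):
    self._timer = tf.estimator.SecondOrStepTimer(every_steps=every_n_steps)

  def begin(self):
    # Look up the global step tensor once the graph is finalized.
    self._global_step_tensor = tf.compat.v1.train.get_global_step()

  def before_run(self, run_context):
    # Ask the session to also fetch the global step on every run() call.
    return tf.estimator.SessionRunArgs(self._global_step_tensor)

  def after_run(self, run_context, run_values):
    global_step = run_values.results
    if self._timer.should_trigger_for_step(global_step):
      self._timer.update_last_triggered_step(global_step)
      tf.compat.v1.logging.info("global step: %d", global_step)
```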
official/utils/logs/hooks_helper.py (view file @ b2c9e3f5)

@@ -57,7 +57,7 @@ def get_train_hooks(name_list, use_tpu=False, **kwargs):
     return []
   if use_tpu:
-    tf.logging.warning("hooks_helper received name_list `{}`, but a TPU is "
+    tf.compat.v1.logging.warning("hooks_helper received name_list `{}`, but a TPU is "
                        "specified. No hooks will be used.".format(name_list))
     return []
@@ -89,7 +89,7 @@ def get_logging_tensor_hook(every_n_iter=100, tensors_to_log=None, **kwargs):  #
   if tensors_to_log is None:
     tensors_to_log = _TENSORS_TO_LOG
-  return tf.train.LoggingTensorHook(
+  return tf.estimator.LoggingTensorHook(
       tensors=tensors_to_log,
       every_n_iter=every_n_iter)
@@ -106,7 +106,7 @@ def get_profiler_hook(model_dir, save_steps=1000, **kwargs):  # pylint: disable=
   Returns a ProfilerHook that writes out timelines that can be loaded into
   profiling tools like chrome://tracing.
   """
-  return tf.train.ProfilerHook(save_steps=save_steps, output_dir=model_dir)
+  return tf.estimator.ProfilerHook(save_steps=save_steps, output_dir=model_dir)
 def get_examples_per_second_hook(every_n_steps=100,
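After the upgrade, the hook factories construct the same classes from the tf.estimator namespace rather than tf.train. A short sketch of the constructors involved (the tensor name and output directory are placeholders, not the values used by hooks_helper.py):

```python
import tensorflow as tf

# Log a named tensor every 100 iterations; "cross_entropy" is a stand-in name.
logging_hook = tf.estimator.LoggingTensorHook(
    tensors={"cross_entropy": "cross_entropy"}, every_n_iter=100)

# Write Chrome-trace timelines every 1000 steps; the directory is hypothetical.
profiler_hook = tf.estimator.ProfilerHook(
    save_steps=1000, output_dir="/tmp/model_dir")
```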
official/utils/logs/hooks_helper_test.py (view file @ b2c9e3f5)

@@ -45,7 +45,7 @@ class BaseTest(unittest.TestCase):
     returned_hook = hooks_helper.get_train_hooks(
         [test_hook_name], model_dir="", **kwargs)
     self.assertEqual(len(returned_hook), 1)
-    self.assertIsInstance(returned_hook[0], tf.train.SessionRunHook)
+    self.assertIsInstance(returned_hook[0], tf.estimator.SessionRunHook)
     self.assertEqual(returned_hook[0].__class__.__name__.lower(),
                      expected_hook_name)
official/utils/logs/hooks_test.py (view file @ b2c9e3f5)

@@ -26,7 +26,7 @@ import tensorflow as tf  # pylint: disable=g-bad-import-order
 from official.utils.logs import hooks
 from official.utils.testing import mock_lib
-tf.logging.set_verbosity(tf.logging.DEBUG)
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.DEBUG)
 class ExamplesPerSecondHookTest(tf.test.TestCase):
@@ -44,9 +44,10 @@ class ExamplesPerSecondHookTest(tf.test.TestCase):
     self.graph = tf.Graph()
     with self.graph.as_default():
-      tf.train.create_global_step()
-      self.train_op = tf.assign_add(tf.train.get_global_step(), 1)
-      self.global_step = tf.train.get_global_step()
+      tf.compat.v1.train.create_global_step()
+      self.train_op = tf.compat.v1.assign_add(
+          tf.compat.v1.train.get_global_step(), 1)
+      self.global_step = tf.compat.v1.train.get_global_step()
   def test_raise_in_both_secs_and_steps(self):
     with self.assertRaises(ValueError):
@@ -71,8 +72,8 @@ class ExamplesPerSecondHookTest(tf.test.TestCase):
         warm_steps=warm_steps,
         metric_logger=self._logger)
-    with tf.train.MonitoredSession(
-        tf.train.ChiefSessionCreator(), [hook]) as mon_sess:
+    with tf.compat.v1.train.MonitoredSession(
+        tf.compat.v1.train.ChiefSessionCreator(), [hook]) as mon_sess:
       for _ in range(every_n_steps):
         # Explicitly run global_step after train_op to get the accurate
         # global_step value
@@ -125,8 +126,8 @@ class ExamplesPerSecondHookTest(tf.test.TestCase):
         every_n_secs=every_n_secs,
         metric_logger=self._logger)
-    with tf.train.MonitoredSession(
-        tf.train.ChiefSessionCreator(), [hook]) as mon_sess:
+    with tf.compat.v1.train.MonitoredSession(
+        tf.compat.v1.train.ChiefSessionCreator(), [hook]) as mon_sess:
       # Explicitly run global_step after train_op to get the accurate
       # global_step value
       mon_sess.run(self.train_op)
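The test drives hooks through a monitored session, which also moves under compat.v1. A self-contained sketch of that pattern (TF 1.x graph mode assumed; the hook choice and iteration counts are arbitrary, not taken from the test):

```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
  # Same setup as the test: a global step and a train_op that increments it.
  tf.compat.v1.train.create_global_step()
  global_step = tf.compat.v1.train.get_global_step()
  train_op = tf.compat.v1.assign_add(global_step, 1)

  # Any SessionRunHook works here; LoggingTensorHook is used for brevity.
  hook = tf.estimator.LoggingTensorHook({"step": global_step}, every_n_iter=10)

  with tf.compat.v1.train.MonitoredSession(
      tf.compat.v1.train.ChiefSessionCreator(), [hook]) as mon_sess:
    for _ in range(30):
      mon_sess.run(train_op)
```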
official/utils/logs/logger.py (view file @ b2c9e3f5)

@@ -119,12 +119,13 @@ class BaseBenchmarkLogger(object):
       eval_results: dict, the result of evaluate.
     """
     if not isinstance(eval_results, dict):
-      tf.logging.warning("eval_results should be dictionary for logging. "
-                         "Got %s", type(eval_results))
+      tf.compat.v1.logging.warning(
+          "eval_results should be dictionary for logging. Got %s",
+          type(eval_results))
       return
-    global_step = eval_results[tf.GraphKeys.GLOBAL_STEP]
+    global_step = eval_results[tf.compat.v1.GraphKeys.GLOBAL_STEP]
     for key in sorted(eval_results):
-      if key != tf.GraphKeys.GLOBAL_STEP:
+      if key != tf.compat.v1.GraphKeys.GLOBAL_STEP:
         self.log_metric(key, eval_results[key], global_step=global_step)
   def log_metric(self, name, value, unit=None, global_step=None, extras=None):
@@ -143,12 +144,12 @@ class BaseBenchmarkLogger(object):
     """
     metric = _process_metric_to_json(name, value, unit, global_step, extras)
     if metric:
-      tf.logging.info("Benchmark metric: %s", metric)
+      tf.compat.v1.logging.info("Benchmark metric: %s", metric)
   def log_run_info(self, model_name, dataset_name, run_params, test_id=None):
-    tf.logging.info("Benchmark run: %s",
-                    _gather_run_info(model_name, dataset_name, run_params,
-                                     test_id))
+    tf.compat.v1.logging.info(
+        "Benchmark run: %s",
+        _gather_run_info(model_name, dataset_name,
+                         run_params, test_id))
   def on_finish(self, status):
     pass
@@ -160,9 +161,9 @@ class BenchmarkFileLogger(BaseBenchmarkLogger):
   def __init__(self, logging_dir):
     super(BenchmarkFileLogger, self).__init__()
     self._logging_dir = logging_dir
-    if not tf.gfile.IsDirectory(self._logging_dir):
-      tf.gfile.MakeDirs(self._logging_dir)
-    self._metric_file_handler = tf.gfile.GFile(
+    if not tf.io.gfile.isdir(self._logging_dir):
+      tf.io.gfile.makedirs(self._logging_dir)
+    self._metric_file_handler = tf.io.gfile.GFile(
         os.path.join(self._logging_dir, METRIC_LOG_FILE_NAME), "a")
   def log_metric(self, name, value, unit=None, global_step=None, extras=None):
@@ -186,8 +187,9 @@ class BenchmarkFileLogger(BaseBenchmarkLogger):
         self._metric_file_handler.write("\n")
         self._metric_file_handler.flush()
       except (TypeError, ValueError) as e:
-        tf.logging.warning("Failed to dump metric to log file: "
-                           "name %s, value %s, error %s", name, value, e)
+        tf.compat.v1.logging.warning(
+            "Failed to dump metric to log file: name %s, value %s, error %s",
+            name, value, e)
   def log_run_info(self, model_name, dataset_name, run_params, test_id=None):
     """Collect most of the TF runtime information for the local env.
@@ -204,14 +206,14 @@ class BenchmarkFileLogger(BaseBenchmarkLogger):
     """
     run_info = _gather_run_info(model_name, dataset_name, run_params, test_id)
-    with tf.gfile.GFile(os.path.join(
+    with tf.io.gfile.GFile(os.path.join(
         self._logging_dir, BENCHMARK_RUN_LOG_FILE_NAME), "w") as f:
       try:
         json.dump(run_info, f)
        f.write("\n")
       except (TypeError, ValueError) as e:
-        tf.logging.warning("Failed to dump benchmark run info to log file: %s",
-                           e)
+        tf.compat.v1.logging.warning(
+            "Failed to dump benchmark run info to log file: %s", e)
   def on_finish(self, status):
     self._metric_file_handler.flush()
@@ -324,7 +326,7 @@ def _process_metric_to_json(
     name, value, unit=None, global_step=None, extras=None):
   """Validate the metric data and generate JSON for insert."""
   if not isinstance(value, numbers.Number):
-    tf.logging.warning(
+    tf.compat.v1.logging.warning(
         "Metric value to log should be a number. Got %s", type(value))
     return None
@@ -341,7 +343,7 @@ def _process_metric_to_json(
 def _collect_tensorflow_info(run_info):
   run_info["tensorflow_version"] = {
-      "version": tf.VERSION, "git_hash": tf.GIT_VERSION}
+      "version": tf.version.VERSION, "git_hash": tf.version.GIT_VERSION}
 def _collect_run_params(run_info, run_params):
@@ -385,7 +387,8 @@ def _collect_cpu_info(run_info):
     run_info["machine_config"]["cpu_info"] = cpu_info
   except ImportError:
-    tf.logging.warn("'cpuinfo' not imported. CPU info will not be logged.")
+    tf.compat.v1.logging.warn(
+        "'cpuinfo' not imported. CPU info will not be logged.")
 def _collect_gpu_info(run_info, session_config=None):
@@ -415,7 +418,8 @@ def _collect_memory_info(run_info):
     run_info["machine_config"]["memory_total"] = vmem.total
     run_info["machine_config"]["memory_available"] = vmem.available
   except ImportError:
-    tf.logging.warn("'psutil' not imported. Memory info will not be logged.")
+    tf.compat.v1.logging.warn(
+        "'psutil' not imported. Memory info will not be logged.")
 def _collect_test_environment(run_info):
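The gfile calls are renamed to their tf.io.gfile equivalents with lower-case names. A small sketch of the logger's write path under the new names (the directory and metric values are placeholders):

```python
import json
import os

import tensorflow as tf

# Hypothetical log directory, purely to illustrate the renamed gfile API.
log_dir = "/tmp/benchmark_logs"
if not tf.io.gfile.isdir(log_dir):      # was tf.gfile.IsDirectory
  tf.io.gfile.makedirs(log_dir)         # was tf.gfile.MakeDirs

with tf.io.gfile.GFile(os.path.join(log_dir, "metric.log"), "a") as f:
  json.dump({"name": "accuracy", "value": 0.999, "global_step": 10000}, f)
  f.write("\n")
```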
official/utils/logs/logger_test.py (view file @ b2c9e3f5)

@@ -78,7 +78,7 @@ class BenchmarkLoggerTest(tf.test.TestCase):
     mock_logger = mock.MagicMock()
     mock_config_benchmark_logger.return_value = mock_logger
     with logger.benchmark_context(None):
-      tf.logging.info("start benchmarking")
+      tf.compat.v1.logging.info("start benchmarking")
     mock_logger.on_finish.assert_called_once_with(logger.RUN_STATUS_SUCCESS)
   @mock.patch("official.utils.logs.logger.config_benchmark_logger")
@@ -95,18 +95,18 @@ class BaseBenchmarkLoggerTest(tf.test.TestCase):
   def setUp(self):
     super(BaseBenchmarkLoggerTest, self).setUp()
-    self._actual_log = tf.logging.info
+    self._actual_log = tf.compat.v1.logging.info
     self.logged_message = None
     def mock_log(*args, **kwargs):
       self.logged_message = args
       self._actual_log(*args, **kwargs)
-    tf.logging.info = mock_log
+    tf.compat.v1.logging.info = mock_log
   def tearDown(self):
     super(BaseBenchmarkLoggerTest, self).tearDown()
-    tf.logging.info = self._actual_log
+    tf.compat.v1.logging.info = self._actual_log
   def test_log_metric(self):
     log = logger.BaseBenchmarkLogger()
@@ -128,16 +128,16 @@ class BenchmarkFileLoggerTest(tf.test.TestCase):
   def tearDown(self):
     super(BenchmarkFileLoggerTest, self).tearDown()
-    tf.gfile.DeleteRecursively(self.get_temp_dir())
+    tf.io.gfile.rmtree(self.get_temp_dir())
     os.environ.clear()
     os.environ.update(self.original_environ)
   def test_create_logging_dir(self):
     non_exist_temp_dir = os.path.join(self.get_temp_dir(), "unknown_dir")
-    self.assertFalse(tf.gfile.IsDirectory(non_exist_temp_dir))
+    self.assertFalse(tf.io.gfile.isdir(non_exist_temp_dir))
     logger.BenchmarkFileLogger(non_exist_temp_dir)
-    self.assertTrue(tf.gfile.IsDirectory(non_exist_temp_dir))
+    self.assertTrue(tf.io.gfile.isdir(non_exist_temp_dir))
   def test_log_metric(self):
     log_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
@@ -145,8 +145,8 @@ class BenchmarkFileLoggerTest(tf.test.TestCase):
     log.log_metric("accuracy", 0.999, global_step=1e4, extras={"name": "value"})
     metric_log = os.path.join(log_dir, "metric.log")
-    self.assertTrue(tf.gfile.Exists(metric_log))
-    with tf.gfile.GFile(metric_log) as f:
+    self.assertTrue(tf.io.gfile.exists(metric_log))
+    with tf.io.gfile.GFile(metric_log) as f:
       metric = json.loads(f.readline())
       self.assertEqual(metric["name"], "accuracy")
       self.assertEqual(metric["value"], 0.999)
@@ -161,8 +161,8 @@ class BenchmarkFileLoggerTest(tf.test.TestCase):
     log.log_metric("loss", 0.02, global_step=1e4)
     metric_log = os.path.join(log_dir, "metric.log")
-    self.assertTrue(tf.gfile.Exists(metric_log))
-    with tf.gfile.GFile(metric_log) as f:
+    self.assertTrue(tf.io.gfile.exists(metric_log))
+    with tf.io.gfile.GFile(metric_log) as f:
       accuracy = json.loads(f.readline())
       self.assertEqual(accuracy["name"], "accuracy")
       self.assertEqual(accuracy["value"], 0.999)
@@ -184,7 +184,7 @@ class BenchmarkFileLoggerTest(tf.test.TestCase):
     log.log_metric("accuracy", const)
     metric_log = os.path.join(log_dir, "metric.log")
-    self.assertFalse(tf.gfile.Exists(metric_log))
+    self.assertFalse(tf.io.gfile.exists(metric_log))
   def test_log_evaluation_result(self):
     eval_result = {"loss": 0.46237424,
@@ -195,8 +195,8 @@ class BenchmarkFileLoggerTest(tf.test.TestCase):
     log.log_evaluation_result(eval_result)
     metric_log = os.path.join(log_dir, "metric.log")
-    self.assertTrue(tf.gfile.Exists(metric_log))
-    with tf.gfile.GFile(metric_log) as f:
+    self.assertTrue(tf.io.gfile.exists(metric_log))
+    with tf.io.gfile.GFile(metric_log) as f:
       accuracy = json.loads(f.readline())
       self.assertEqual(accuracy["name"], "accuracy")
       self.assertEqual(accuracy["value"], 0.9285)
@@ -216,7 +216,7 @@ class BenchmarkFileLoggerTest(tf.test.TestCase):
     log.log_evaluation_result(eval_result)
     metric_log = os.path.join(log_dir, "metric.log")
-    self.assertFalse(tf.gfile.Exists(metric_log))
+    self.assertFalse(tf.io.gfile.exists(metric_log))
   @mock.patch("official.utils.logs.logger._gather_run_info")
   def test_log_run_info(self, mock_gather_run_info):
@@ -229,8 +229,8 @@ class BenchmarkFileLoggerTest(tf.test.TestCase):
     log.log_run_info("model_name", "dataset_name", {})
     run_log = os.path.join(log_dir, "benchmark_run.log")
-    self.assertTrue(tf.gfile.Exists(run_log))
-    with tf.gfile.GFile(run_log) as f:
+    self.assertTrue(tf.io.gfile.exists(run_log))
+    with tf.io.gfile.GFile(run_log) as f:
       run_info = json.loads(f.readline())
       self.assertEqual(run_info["model_name"], "model_name")
       self.assertEqual(run_info["dataset"], "dataset_name")
@@ -240,8 +240,10 @@ class BenchmarkFileLoggerTest(tf.test.TestCase):
     run_info = {}
     logger._collect_tensorflow_info(run_info)
     self.assertNotEqual(run_info["tensorflow_version"], {})
-    self.assertEqual(run_info["tensorflow_version"]["version"], tf.VERSION)
-    self.assertEqual(run_info["tensorflow_version"]["git_hash"], tf.GIT_VERSION)
+    self.assertEqual(run_info["tensorflow_version"]["version"],
+                     tf.version.VERSION)
+    self.assertEqual(run_info["tensorflow_version"]["git_hash"],
+                     tf.version.GIT_VERSION)
   def test_collect_run_params(self):
     run_info = {}
@@ -315,7 +317,7 @@ class BenchmarkBigQueryLoggerTest(tf.test.TestCase):
   def tearDown(self):
     super(BenchmarkBigQueryLoggerTest, self).tearDown()
-    tf.gfile.DeleteRecursively(self.get_temp_dir())
+    tf.io.gfile.rmtree(self.get_temp_dir())
     os.environ.clear()
     os.environ.update(self.original_environ)
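The test above saves and restores tf.compat.v1.logging.info by hand in setUp/tearDown. An equivalent, shorter pattern using mock.patch.object would be (a sketch only, not what the repo does):

```python
from unittest import mock

import tensorflow as tf

# Temporarily replace the logging function and assert on how it was called.
with mock.patch.object(tf.compat.v1.logging, "info") as mock_info:
  tf.compat.v1.logging.info("start benchmarking")
  mock_info.assert_called_once_with("start benchmarking")
```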
official/utils/logs/metric_hook.py (view file @ b2c9e3f5)

@@ -21,7 +21,7 @@ from __future__ import print_function
 import tensorflow as tf  # pylint: disable=g-bad-import-order
-class LoggingMetricHook(tf.train.LoggingTensorHook):
+class LoggingMetricHook(tf.estimator.LoggingTensorHook):
   """Hook to log benchmark metric information.
   This hook is very similar as tf.train.LoggingTensorHook, which logs given
@@ -68,7 +68,7 @@ class LoggingMetricHook(tf.train.LoggingTensorHook):
   def begin(self):
     super(LoggingMetricHook, self).begin()
-    self._global_step_tensor = tf.train.get_global_step()
+    self._global_step_tensor = tf.compat.v1.train.get_global_step()
     if self._global_step_tensor is None:
       raise RuntimeError(
           "Global step should be created to use LoggingMetricHook.")
official/utils/logs/metric_hook_test.py (view file @ b2c9e3f5)

@@ -39,7 +39,7 @@ class LoggingMetricHookTest(tf.test.TestCase):
   def tearDown(self):
     super(LoggingMetricHookTest, self).tearDown()
-    tf.gfile.DeleteRecursively(self.get_temp_dir())
+    tf.io.gfile.rmtree(self.get_temp_dir())
   def test_illegal_args(self):
     with self.assertRaisesRegexp(ValueError, "nvalid every_n_iter"):
@@ -55,15 +55,15 @@ class LoggingMetricHookTest(tf.test.TestCase):
       metric_hook.LoggingMetricHook(tensors=["t"], every_n_iter=5)
   def test_print_at_end_only(self):
-    with tf.Graph().as_default(), tf.Session() as sess:
-      tf.train.get_or_create_global_step()
+    with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
+      tf.compat.v1.train.get_or_create_global_step()
       t = tf.constant(42.0, name="foo")
       train_op = tf.constant(3)
       hook = metric_hook.LoggingMetricHook(
           tensors=[t.name], at_end=True, metric_logger=self._logger)
       hook.begin()
       mon_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
-      sess.run(tf.global_variables_initializer())
+      sess.run(tf.compat.v1.global_variables_initializer())
       for _ in range(3):
         mon_sess.run(train_op)
@@ -88,8 +88,8 @@ class LoggingMetricHookTest(tf.test.TestCase):
       hook.begin()
   def test_log_tensors(self):
-    with tf.Graph().as_default(), tf.Session() as sess:
-      tf.train.get_or_create_global_step()
+    with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
+      tf.compat.v1.train.get_or_create_global_step()
       t1 = tf.constant(42.0, name="foo")
       t2 = tf.constant(43.0, name="bar")
       train_op = tf.constant(3)
@@ -97,7 +97,7 @@ class LoggingMetricHookTest(tf.test.TestCase):
           tensors=[t1, t2], at_end=True, metric_logger=self._logger)
       hook.begin()
       mon_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
-      sess.run(tf.global_variables_initializer())
+      sess.run(tf.compat.v1.global_variables_initializer())
       for _ in range(3):
         mon_sess.run(train_op)
@@ -126,7 +126,7 @@ class LoggingMetricHookTest(tf.test.TestCase):
           metric_logger=self._logger)
       hook.begin()
       mon_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
-      sess.run(tf.global_variables_initializer())
+      sess.run(tf.compat.v1.global_variables_initializer())
       mon_sess.run(train_op)
       self.assertRegexpMatches(str(self._logger.logged_metric), t.name)
       for _ in range(3):
@@ -153,15 +153,15 @@ class LoggingMetricHookTest(tf.test.TestCase):
       self.assertEqual(str(self._logger.logged_metric).find(t.name), -1)
   def test_print_every_n_steps(self):
-    with tf.Graph().as_default(), tf.Session() as sess:
-      tf.train.get_or_create_global_step()
+    with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
+      tf.compat.v1.train.get_or_create_global_step()
       self._validate_print_every_n_steps(sess, at_end=False)
       # Verify proper reset.
       self._validate_print_every_n_steps(sess, at_end=False)
   def test_print_every_n_steps_and_end(self):
-    with tf.Graph().as_default(), tf.Session() as sess:
-      tf.train.get_or_create_global_step()
+    with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
+      tf.compat.v1.train.get_or_create_global_step()
       self._validate_print_every_n_steps(sess, at_end=True)
       # Verify proper reset.
       self._validate_print_every_n_steps(sess, at_end=True)
@@ -175,7 +175,7 @@ class LoggingMetricHookTest(tf.test.TestCase):
           metric_logger=self._logger)
       hook.begin()
       mon_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
-      sess.run(tf.global_variables_initializer())
+      sess.run(tf.compat.v1.global_variables_initializer())
       mon_sess.run(train_op)
       self.assertRegexpMatches(str(self._logger.logged_metric), t.name)
@@ -199,15 +199,15 @@ class LoggingMetricHookTest(tf.test.TestCase):
       self.assertEqual(str(self._logger.logged_metric).find(t.name), -1)
   def test_print_every_n_secs(self):
-    with tf.Graph().as_default(), tf.Session() as sess:
-      tf.train.get_or_create_global_step()
+    with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
+      tf.compat.v1.train.get_or_create_global_step()
       self._validate_print_every_n_secs(sess, at_end=False)
       # Verify proper reset.
       self._validate_print_every_n_secs(sess, at_end=False)
   def test_print_every_n_secs_and_end(self):
-    with tf.Graph().as_default(), tf.Session() as sess:
-      tf.train.get_or_create_global_step()
+    with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
+      tf.compat.v1.train.get_or_create_global_step()
       self._validate_print_every_n_secs(sess, at_end=True)
       # Verify proper reset.
       self._validate_print_every_n_secs(sess, at_end=True)
official/utils/logs/mlperf_helper.py (view file @ b2c9e3f5)

@@ -94,7 +94,7 @@ def get_mlperf_log():
     version = pkg_resources.get_distribution("mlperf_compliance")
     version = tuple(int(i) for i in version.version.split("."))
     if version < _MIN_VERSION:
-      tf.logging.warning(
+      tf.compat.v1.logging.warning(
           "mlperf_compliance is version {}, must be >= {}".format(
               ".".join([str(i) for i in version]),
               ".".join([str(i) for i in _MIN_VERSION])))
@@ -187,6 +187,6 @@ def clear_system_caches():
 if __name__ == "__main__":
-  tf.logging.set_verbosity(tf.logging.INFO)
+  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
   with LOGGER(True):
     ncf_print(key=TAGS.RUN_START)
official/utils/misc/model_helpers.py (view file @ b2c9e3f5)

@@ -48,7 +48,7 @@ def past_stop_threshold(stop_threshold, eval_metric):
         "must be a number.")
   if eval_metric >= stop_threshold:
-    tf.logging.info(
+    tf.compat.v1.logging.info(
         "Stop threshold of {} was passed with metric value {}.".format(
             stop_threshold, eval_metric))
     return True
@@ -87,7 +87,7 @@ def generate_synthetic_data(
 def apply_clean(flags_obj):
-  if flags_obj.clean and tf.gfile.Exists(flags_obj.model_dir):
-    tf.logging.info("--clean flag set. Removing existing model dir: {}".format(
+  if flags_obj.clean and tf.io.gfile.exists(flags_obj.model_dir):
+    tf.compat.v1.logging.info("--clean flag set. Removing existing model dir: {}".format(
         flags_obj.model_dir))
-    tf.gfile.DeleteRecursively(flags_obj.model_dir)
+    tf.io.gfile.rmtree(flags_obj.model_dir)
official/utils/misc/model_helpers_test.py (view file @ b2c9e3f5)

@@ -69,13 +69,13 @@ class SyntheticDataTest(tf.test.TestCase):
   """Tests for generate_synthetic_data."""
   def test_generate_synethetic_data(self):
-    input_element, label_element = model_helpers.generate_synthetic_data(
-        input_shape=tf.TensorShape([5]),
-        input_value=123,
-        input_dtype=tf.float32,
-        label_shape=tf.TensorShape([]),
-        label_value=456,
-        label_dtype=tf.int32).make_one_shot_iterator().get_next()
+    input_element, label_element = tf.compat.v1.data.make_one_shot_iterator(
+        model_helpers.generate_synthetic_data(
+            input_shape=tf.TensorShape([5]),
+            input_value=123,
+            input_dtype=tf.float32,
+            label_shape=tf.TensorShape([]),
+            label_value=456,
+            label_dtype=tf.int32)).get_next()
     with self.test_session() as sess:
       for n in range(5):
@@ -89,7 +89,7 @@ class SyntheticDataTest(tf.test.TestCase):
         input_value=43.5,
         input_dtype=tf.float32)
-    element = d.make_one_shot_iterator().get_next()
+    element = tf.compat.v1.data.make_one_shot_iterator(d).get_next()
     self.assertFalse(isinstance(element, tuple))
     with self.test_session() as sess:
@@ -102,7 +102,7 @@ class SyntheticDataTest(tf.test.TestCase):
         'b': {'c': tf.TensorShape([3]), 'd': tf.TensorShape([])}},
         input_value=1.1)
-    element = d.make_one_shot_iterator().get_next()
+    element = tf.compat.v1.data.make_one_shot_iterator(d).get_next()
     self.assertIn('a', element)
     self.assertIn('b', element)
     self.assertEquals(len(element['b']), 2)
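Dataset.make_one_shot_iterator() is no longer an instance method in the 2.x surface; the upgrade routes it through tf.compat.v1.data. A minimal sketch (TF 1.x graph mode assumed; the dataset is a stand-in for generate_synthetic_data()):

```python
import tensorflow as tf

# Stand-in dataset: a single constant, repeated indefinitely.
dataset = tf.data.Dataset.from_tensors(123).repeat()

# 1.x instance-method spelling (dropped from the 2.x surface):
#   element = dataset.make_one_shot_iterator().get_next()
# Post-upgrade, free-function spelling:
element = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()

with tf.compat.v1.Session() as sess:
  print(sess.run(element))  # -> 123
```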
official/utils/testing/reference_data.py (view file @ b2c9e3f5)

@@ -170,12 +170,12 @@ class BaseTest(tf.test.TestCase):
     # Serialize graph for comparison.
     graph_bytes = graph.as_graph_def().SerializeToString()
     expected_file = os.path.join(data_dir, "expected_graph")
-    with tf.gfile.Open(expected_file, "wb") as f:
+    with tf.io.gfile.GFile(expected_file, "wb") as f:
       f.write(graph_bytes)
     with graph.as_default():
-      init = tf.global_variables_initializer()
-      saver = tf.train.Saver()
+      init = tf.compat.v1.global_variables_initializer()
+      saver = tf.compat.v1.train.Saver()
     with self.test_session(graph=graph) as sess:
       sess.run(init)
@@ -191,11 +191,11 @@ class BaseTest(tf.test.TestCase):
     if correctness_function is not None:
       results = correctness_function(*eval_results)
-      with tf.gfile.Open(os.path.join(data_dir, "results.json"), "w") as f:
+      with tf.io.gfile.GFile(os.path.join(data_dir, "results.json"), "w") as f:
         json.dump(results, f)
-    with tf.gfile.Open(os.path.join(data_dir, "tf_version.json"), "w") as f:
-      json.dump([tf.VERSION, tf.GIT_VERSION], f)
+    with tf.io.gfile.GFile(os.path.join(data_dir, "tf_version.json"), "w") as f:
+      json.dump([tf.version.VERSION, tf.version.GIT_VERSION], f)
   def _evaluate_test_case(self, name, graph, ops_to_eval, correctness_function):
     """Determine if a graph agrees with the reference data.
@@ -216,7 +216,7 @@ class BaseTest(tf.test.TestCase):
     # Serialize graph for comparison.
     graph_bytes = graph.as_graph_def().SerializeToString()
     expected_file = os.path.join(data_dir, "expected_graph")
-    with tf.gfile.Open(expected_file, "rb") as f:
+    with tf.io.gfile.GFile(expected_file, "rb") as f:
       expected_graph_bytes = f.read()
     # The serialization is non-deterministic byte-for-byte. Instead there is
     # a utility which evaluates the semantics of the two graphs to test for
@@ -228,19 +228,19 @@ class BaseTest(tf.test.TestCase):
         graph_bytes, expected_graph_bytes).decode("utf-8")
     with graph.as_default():
-      init = tf.global_variables_initializer()
-      saver = tf.train.Saver()
+      init = tf.compat.v1.global_variables_initializer()
+      saver = tf.compat.v1.train.Saver()
-    with tf.gfile.Open(os.path.join(data_dir, "tf_version.json"), "r") as f:
+    with tf.io.gfile.GFile(os.path.join(data_dir, "tf_version.json"), "r") as f:
       tf_version_reference, tf_git_version_reference = json.load(f)  # pylint: disable=unpacking-non-sequence
     tf_version_comparison = ""
-    if tf.GIT_VERSION != tf_git_version_reference:
+    if tf.version.GIT_VERSION != tf_git_version_reference:
       tf_version_comparison = (
           "Test was built using: {} (git = {})\n"
           "Local TensorFlow version: {} (git = {})"
          .format(tf_version_reference, tf_git_version_reference,
-                  tf.VERSION, tf.GIT_VERSION)
+                  tf.version.VERSION, tf.version.GIT_VERSION)
       )
     with self.test_session(graph=graph) as sess:
@@ -249,7 +249,7 @@ class BaseTest(tf.test.TestCase):
       saver.restore(sess=sess, save_path=os.path.join(
           data_dir, self.ckpt_prefix))
       if differences:
-        tf.logging.warn(
+        tf.compat.v1.logging.warn(
            "The provided graph is different than expected:\n{}\n"
            "However the weights were still able to be loaded.\n{}".format(
                differences, tf_version_comparison)
@@ -262,7 +262,7 @@ class BaseTest(tf.test.TestCase):
       eval_results = [op.eval() for op in ops_to_eval]
       if correctness_function is not None:
         results = correctness_function(*eval_results)
-        with tf.gfile.Open(os.path.join(data_dir, "results.json"), "r") as f:
+        with tf.io.gfile.GFile(os.path.join(data_dir, "results.json"), "r") as f:
          expected_results = json.load(f)
        self.assertAllClose(results, expected_results)
@@ -298,7 +298,7 @@ class BaseTest(tf.test.TestCase):
          correctness_function=correctness_function
      )
    except:
-      tf.logging.error("Failed unittest {}".format(name))
+      tf.compat.v1.logging.error("Failed unittest {}".format(name))
      raise
    else:
      self._construct_and_save_reference_files(
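tf.VERSION and tf.GIT_VERSION move under tf.version. A short sketch of the new spelling, mirroring the tf_version.json write above (the output path here is hypothetical):

```python
import json

import tensorflow as tf

# The version constants now live under tf.version rather than on tf directly.
with tf.io.gfile.GFile("/tmp/tf_version.json", "w") as f:
  json.dump([tf.version.VERSION, tf.version.GIT_VERSION], f)
```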
official/utils/testing/reference_data_test.py (view file @ b2c9e3f5)

@@ -63,12 +63,12 @@ class GoldenBaseTest(reference_data.BaseTest):
     with g.as_default():
       seed = self.name_to_seed(name)
       seed = seed + 1 if bad_seed else seed
-      tf.set_random_seed(seed)
+      tf.compat.v1.set_random_seed(seed)
       tensor_name = "wrong_tensor" if wrong_name else "input_tensor"
       tensor_shape = (1, 2) if wrong_shape else (1, 1)
-      input_tensor = tf.get_variable(
+      input_tensor = tf.compat.v1.get_variable(
           tensor_name, dtype=tf.float32,
-          initializer=tf.random_uniform(tensor_shape, maxval=1)
+          initializer=tf.random.uniform(tensor_shape, maxval=1)
       )
    def correctness_function(tensor_result):
@@ -86,13 +86,13 @@ class GoldenBaseTest(reference_data.BaseTest):
     g = tf.Graph()
     with g.as_default():
-      tf.set_random_seed(self.name_to_seed(name))
-      input_tensor = tf.get_variable(
+      tf.compat.v1.set_random_seed(self.name_to_seed(name))
+      input_tensor = tf.compat.v1.get_variable(
          "input_tensor", dtype=tf.float32,
-          initializer=tf.random_uniform((1, 2), maxval=1)
+          initializer=tf.random.uniform((1, 2), maxval=1)
      )
-      layer = tf.layers.dense(inputs=input_tensor, units=4)
-      layer = tf.layers.dense(inputs=layer, units=1)
+      layer = tf.compat.v1.layers.dense(inputs=input_tensor, units=4)
+      layer = tf.compat.v1.layers.dense(inputs=layer, units=1)
      self._save_or_test_ops(
          name=name, graph=g, ops_to_eval=[layer], test=test,
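The graph-mode idioms used by this test (set_random_seed, get_variable, layers.dense) all gain a compat.v1 prefix, while tf.random_uniform becomes tf.random.uniform. A self-contained sketch under the post-upgrade names (TF 1.x graph mode assumed; the seed and layer sizes are arbitrary, not taken from the test):

```python
import tensorflow as tf

g = tf.Graph()
with g.as_default():
  tf.compat.v1.set_random_seed(1234)                       # was tf.set_random_seed
  input_tensor = tf.compat.v1.get_variable(                # was tf.get_variable
      "input_tensor", dtype=tf.float32,
      initializer=tf.random.uniform((1, 2), maxval=1))     # was tf.random_uniform
  layer = tf.compat.v1.layers.dense(inputs=input_tensor, units=4)
  layer = tf.compat.v1.layers.dense(inputs=layer, units=1)
  init = tf.compat.v1.global_variables_initializer()

with tf.compat.v1.Session(graph=g) as sess:
  sess.run(init)
  print(sess.run(layer))
```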