ModelZoo / ResNet50_tensorflow

Commit b2c9e3f5, authored Feb 08, 2019 by Goldie Gadde, committed by Toby Boyd on Feb 08, 2019

Revert "Revert "tf_upgrade_v2 on resnet and utils folders. (#6154)" (#6162)" (#6167)

This reverts commit 57e07520.

parent 57e07520
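The commit reapplies the output of the tf_upgrade_v2 conversion script to the resnet and utils folders. As a rough, hedged illustration (not taken from this commit), the renames it applies fall into a few families; the snippet below assumes TensorFlow 1.13 or newer, where tf.compat.v1, tf.io.gfile, and tf.version exist:

import tensorflow as tf  # assumes TF 1.13+ so tf.compat.v1, tf.io.gfile and tf.version are available

# v1 logging now lives under tf.compat.v1.logging (was tf.logging).
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

# Version constants moved under tf.version (were tf.VERSION / tf.GIT_VERSION).
print(tf.version.VERSION, tf.version.GIT_VERSION)

# File utilities moved from tf.gfile to tf.io.gfile with lower_case names.
tf.io.gfile.makedirs("/tmp/tf_upgrade_v2_example")  # illustrative path, not from the repo

The per-file diffs below are all instances of these substitutions (plus the tf.train hook classes moving under tf.estimator).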
Changes: 34 files in total (shown across two pages). This page shows 14 changed files with 121 additions and 114 deletions (+121 / -114).
official/utils/flags/_device.py                +1   -1
official/utils/logs/hooks.py                   +4   -4
official/utils/logs/hooks_helper.py            +3   -3
official/utils/logs/hooks_helper_test.py       +1   -1
official/utils/logs/hooks_test.py              +9   -8
official/utils/logs/logger.py                  +24  -20
official/utils/logs/logger_test.py             +22  -20
official/utils/logs/metric_hook.py             +2   -2
official/utils/logs/metric_hook_test.py        +17  -17
official/utils/logs/mlperf_helper.py           +2   -2
official/utils/misc/model_helpers.py           +4   -4
official/utils/misc/model_helpers_test.py      +9   -9
official/utils/testing/reference_data.py       +15  -15
official/utils/testing/reference_data_test.py  +8   -8
official/utils/flags/_device.py

@@ -39,7 +39,7 @@ def require_cloud_storage(flag_names):
     valid_flags = True
     for key in flag_names:
       if not flag_values[key].startswith("gs://"):
-        tf.logging.error("{} must be a GCS path.".format(key))
+        tf.compat.v1.logging.error("{} must be a GCS path.".format(key))
         valid_flags = False
     return valid_flags
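For context, a minimal sketch of the check this hunk touches: the validator rejects any flag value that is not a gs:// URI and reports it through the v1 logging shim. The flag name and value below are invented for illustration, not taken from the repo.

import tensorflow as tf

flag_values = {"data_dir": "/local/path"}   # hypothetical flag name and value
for key in ("data_dir",):
  if not flag_values[key].startswith("gs://"):
    tf.compat.v1.logging.error("{} must be a GCS path.".format(key))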
official/utils/logs/hooks.py

@@ -25,7 +25,7 @@ import tensorflow as tf  # pylint: disable=g-bad-import-order
 from official.utils.logs import logger


-class ExamplesPerSecondHook(tf.train.SessionRunHook):
+class ExamplesPerSecondHook(tf.estimator.SessionRunHook):
   """Hook to print out examples per second.

   Total time is tracked and then divided by the total number of steps

@@ -66,7 +66,7 @@ class ExamplesPerSecondHook(tf.train.SessionRunHook):
     self._logger = metric_logger or logger.BaseBenchmarkLogger()

-    self._timer = tf.train.SecondOrStepTimer(
+    self._timer = tf.estimator.SecondOrStepTimer(
         every_steps=every_n_steps, every_secs=every_n_secs)

     self._step_train_time = 0

@@ -76,7 +76,7 @@ class ExamplesPerSecondHook(tf.train.SessionRunHook):
   def begin(self):
     """Called once before using the session to check global step."""
-    self._global_step_tensor = tf.train.get_global_step()
+    self._global_step_tensor = tf.compat.v1.train.get_global_step()
     if self._global_step_tensor is None:
       raise RuntimeError(
           "Global step should be created to use StepCounterHook.")

@@ -90,7 +90,7 @@ class ExamplesPerSecondHook(tf.train.SessionRunHook):
     Returns:
       A SessionRunArgs object or None if never triggered.
     """
-    return tf.train.SessionRunArgs(self._global_step_tensor)
+    return tf.estimator.SessionRunArgs(self._global_step_tensor)

   def after_run(self, run_context, run_values):  # pylint: disable=unused-argument
     """Called after each call to run().
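A minimal, hedged sketch of a hook written against the renamed base class used above (tf.estimator.SessionRunHook instead of tf.train.SessionRunHook). This is not the repo's ExamplesPerSecondHook, just the smallest hook that exercises the new symbols in graph-mode (TF 1.x style) execution:

import tensorflow as tf

class StepPrinterHook(tf.estimator.SessionRunHook):
  """Logs the global step after every run() call."""

  def begin(self):
    # Global step lookup moved under tf.compat.v1.train.
    self._global_step_tensor = tf.compat.v1.train.get_global_step()

  def before_run(self, run_context):
    # SessionRunArgs also moved from tf.train to tf.estimator.
    return tf.estimator.SessionRunArgs(self._global_step_tensor)

  def after_run(self, run_context, run_values):
    tf.compat.v1.logging.info("global step = %s", run_values.results)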
official/utils/logs/hooks_helper.py

@@ -57,7 +57,7 @@ def get_train_hooks(name_list, use_tpu=False, **kwargs):
     return []

   if use_tpu:
-    tf.logging.warning("hooks_helper received name_list `{}`, but a TPU is "
+    tf.compat.v1.logging.warning("hooks_helper received name_list `{}`, but a TPU is "
                        "specified. No hooks will be used.".format(name_list))
     return []

@@ -89,7 +89,7 @@ def get_logging_tensor_hook(every_n_iter=100, tensors_to_log=None, **kwargs):  #
   if tensors_to_log is None:
     tensors_to_log = _TENSORS_TO_LOG

-  return tf.train.LoggingTensorHook(
+  return tf.estimator.LoggingTensorHook(
       tensors=tensors_to_log,
       every_n_iter=every_n_iter)

@@ -106,7 +106,7 @@ def get_profiler_hook(model_dir, save_steps=1000, **kwargs):  # pylint: disable=
   Returns a ProfilerHook that writes out timelines that can be loaded into
   profiling tools like chrome://tracing.
   """
-  return tf.train.ProfilerHook(save_steps=save_steps, output_dir=model_dir)
+  return tf.estimator.ProfilerHook(save_steps=save_steps, output_dir=model_dir)


 def get_examples_per_second_hook(every_n_steps=100,
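The hook constructors themselves keep their signatures; only the namespace changes from tf.train to tf.estimator. A hedged sketch with made-up tensor names and an illustrative output directory:

import tensorflow as tf

# Log a named tensor every 100 steps (tensor name is illustrative).
logging_hook = tf.estimator.LoggingTensorHook(
    tensors={"loss": "cross_entropy"}, every_n_iter=100)

# Write chrome://tracing timelines every 1000 steps.
profiler_hook = tf.estimator.ProfilerHook(
    save_steps=1000, output_dir="/tmp/profile")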
official/utils/logs/hooks_helper_test.py

@@ -45,7 +45,7 @@ class BaseTest(unittest.TestCase):
     returned_hook = hooks_helper.get_train_hooks(
         [test_hook_name], model_dir="", **kwargs)
     self.assertEqual(len(returned_hook), 1)
-    self.assertIsInstance(returned_hook[0], tf.train.SessionRunHook)
+    self.assertIsInstance(returned_hook[0], tf.estimator.SessionRunHook)
     self.assertEqual(returned_hook[0].__class__.__name__.lower(),
                      expected_hook_name)
official/utils/logs/hooks_test.py

@@ -26,7 +26,7 @@ import tensorflow as tf  # pylint: disable=g-bad-import-order
 from official.utils.logs import hooks
 from official.utils.testing import mock_lib

-tf.logging.set_verbosity(tf.logging.DEBUG)
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.DEBUG)


 class ExamplesPerSecondHookTest(tf.test.TestCase):

@@ -44,9 +44,10 @@ class ExamplesPerSecondHookTest(tf.test.TestCase):
     self.graph = tf.Graph()
     with self.graph.as_default():
-      tf.train.create_global_step()
-      self.train_op = tf.assign_add(tf.train.get_global_step(), 1)
-      self.global_step = tf.train.get_global_step()
+      tf.compat.v1.train.create_global_step()
+      self.train_op = tf.compat.v1.assign_add(
+          tf.compat.v1.train.get_global_step(), 1)
+      self.global_step = tf.compat.v1.train.get_global_step()

   def test_raise_in_both_secs_and_steps(self):
     with self.assertRaises(ValueError):

@@ -71,8 +72,8 @@ class ExamplesPerSecondHookTest(tf.test.TestCase):
         warm_steps=warm_steps,
         metric_logger=self._logger)

-    with tf.train.MonitoredSession(
-        tf.train.ChiefSessionCreator(), [hook]) as mon_sess:
+    with tf.compat.v1.train.MonitoredSession(
+        tf.compat.v1.train.ChiefSessionCreator(), [hook]) as mon_sess:
       for _ in range(every_n_steps):
         # Explicitly run global_step after train_op to get the accurate
         # global_step value

@@ -125,8 +126,8 @@ class ExamplesPerSecondHookTest(tf.test.TestCase):
         every_n_secs=every_n_secs,
         metric_logger=self._logger)

-    with tf.train.MonitoredSession(
-        tf.train.ChiefSessionCreator(), [hook]) as mon_sess:
+    with tf.compat.v1.train.MonitoredSession(
+        tf.compat.v1.train.ChiefSessionCreator(), [hook]) as mon_sess:
       # Explicitly run global_step after train_op to get the accurate
       # global_step value
       mon_sess.run(self.train_op)
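A self-contained sketch of the upgraded test scaffolding above: the global step, assign_add, MonitoredSession, and ChiefSessionCreator all move under tf.compat.v1. It assumes graph-mode (TF 1.x style) execution and leaves out the hook itself to stay minimal:

import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
  tf.compat.v1.train.create_global_step()
  train_op = tf.compat.v1.assign_add(tf.compat.v1.train.get_global_step(), 1)
  with tf.compat.v1.train.MonitoredSession(
      tf.compat.v1.train.ChiefSessionCreator()) as mon_sess:
    for _ in range(3):
      mon_sess.run(train_op)   # increments the global step three times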
official/utils/logs/logger.py

@@ -119,12 +119,13 @@ class BaseBenchmarkLogger(object):
       eval_results: dict, the result of evaluate.
     """
     if not isinstance(eval_results, dict):
-      tf.logging.warning("eval_results should be dictionary for logging. "
-                         "Got %s", type(eval_results))
+      tf.compat.v1.logging.warning(
+          "eval_results should be dictionary for logging. Got %s",
+          type(eval_results))
       return
-    global_step = eval_results[tf.GraphKeys.GLOBAL_STEP]
+    global_step = eval_results[tf.compat.v1.GraphKeys.GLOBAL_STEP]
     for key in sorted(eval_results):
-      if key != tf.GraphKeys.GLOBAL_STEP:
+      if key != tf.compat.v1.GraphKeys.GLOBAL_STEP:
         self.log_metric(key, eval_results[key], global_step=global_step)

   def log_metric(self, name, value, unit=None, global_step=None, extras=None):

@@ -143,12 +144,12 @@ class BaseBenchmarkLogger(object):
     """
     metric = _process_metric_to_json(name, value, unit, global_step, extras)
     if metric:
-      tf.logging.info("Benchmark metric: %s", metric)
+      tf.compat.v1.logging.info("Benchmark metric: %s", metric)

   def log_run_info(self, model_name, dataset_name, run_params, test_id=None):
-    tf.logging.info("Benchmark run: %s",
-                    _gather_run_info(model_name, dataset_name, run_params,
-                                     test_id))
+    tf.compat.v1.logging.info(
+        "Benchmark run: %s",
+        _gather_run_info(model_name, dataset_name, run_params, test_id))

   def on_finish(self, status):
     pass

@@ -160,9 +161,9 @@ class BenchmarkFileLogger(BaseBenchmarkLogger):
   def __init__(self, logging_dir):
     super(BenchmarkFileLogger, self).__init__()
     self._logging_dir = logging_dir
-    if not tf.gfile.IsDirectory(self._logging_dir):
-      tf.gfile.MakeDirs(self._logging_dir)
-    self._metric_file_handler = tf.gfile.GFile(
+    if not tf.io.gfile.isdir(self._logging_dir):
+      tf.io.gfile.makedirs(self._logging_dir)
+    self._metric_file_handler = tf.io.gfile.GFile(
         os.path.join(self._logging_dir, METRIC_LOG_FILE_NAME), "a")

   def log_metric(self, name, value, unit=None, global_step=None, extras=None):

@@ -186,8 +187,9 @@ class BenchmarkFileLogger(BaseBenchmarkLogger):
       self._metric_file_handler.write("\n")
       self._metric_file_handler.flush()
     except (TypeError, ValueError) as e:
-      tf.logging.warning("Failed to dump metric to log file: "
-                         "name %s, value %s, error %s", name, value, e)
+      tf.compat.v1.logging.warning(
+          "Failed to dump metric to log file: name %s, value %s, error %s",
+          name, value, e)

   def log_run_info(self, model_name, dataset_name, run_params, test_id=None):
     """Collect most of the TF runtime information for the local env.

@@ -204,14 +206,14 @@ class BenchmarkFileLogger(BaseBenchmarkLogger):
     """
     run_info = _gather_run_info(model_name, dataset_name, run_params, test_id)

-    with tf.gfile.GFile(os.path.join(
+    with tf.io.gfile.GFile(os.path.join(
         self._logging_dir, BENCHMARK_RUN_LOG_FILE_NAME), "w") as f:
       try:
         json.dump(run_info, f)
         f.write("\n")
       except (TypeError, ValueError) as e:
-        tf.logging.warning("Failed to dump benchmark run info to log file: %s",
-                           e)
+        tf.compat.v1.logging.warning(
+            "Failed to dump benchmark run info to log file: %s", e)

   def on_finish(self, status):
     self._metric_file_handler.flush()

@@ -324,7 +326,7 @@ def _process_metric_to_json(
     name, value, unit=None, global_step=None, extras=None):
   """Validate the metric data and generate JSON for insert."""
   if not isinstance(value, numbers.Number):
-    tf.logging.warning(
+    tf.compat.v1.logging.warning(
         "Metric value to log should be a number. Got %s", type(value))
     return None

@@ -341,7 +343,7 @@ def _process_metric_to_json(
 def _collect_tensorflow_info(run_info):
   run_info["tensorflow_version"] = {
-      "version": tf.VERSION, "git_hash": tf.GIT_VERSION}
+      "version": tf.version.VERSION, "git_hash": tf.version.GIT_VERSION}


 def _collect_run_params(run_info, run_params):

@@ -385,7 +387,8 @@ def _collect_cpu_info(run_info):
     run_info["machine_config"]["cpu_info"] = cpu_info
   except ImportError:
-    tf.logging.warn("'cpuinfo' not imported. CPU info will not be logged.")
+    tf.compat.v1.logging.warn(
+        "'cpuinfo' not imported. CPU info will not be logged.")


 def _collect_gpu_info(run_info, session_config=None):

@@ -415,7 +418,8 @@ def _collect_memory_info(run_info):
     run_info["machine_config"]["memory_total"] = vmem.total
     run_info["machine_config"]["memory_available"] = vmem.available
   except ImportError:
-    tf.logging.warn("'psutil' not imported. Memory info will not be logged.")
+    tf.compat.v1.logging.warn(
+        "'psutil' not imported. Memory info will not be logged.")


 def _collect_test_environment(run_info):
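A hedged sketch of the tf.gfile to tf.io.gfile renames used above, run against a scratch directory (the path and file contents are illustrative, not from the repo):

import os
import tensorflow as tf

log_dir = "/tmp/benchmark_logs"             # illustrative path
if not tf.io.gfile.isdir(log_dir):          # was tf.gfile.IsDirectory
  tf.io.gfile.makedirs(log_dir)             # was tf.gfile.MakeDirs
with tf.io.gfile.GFile(os.path.join(log_dir, "metric.log"), "a") as f:
  f.write('{"name": "accuracy", "value": 0.999}\n')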
official/utils/logs/logger_test.py

@@ -78,7 +78,7 @@ class BenchmarkLoggerTest(tf.test.TestCase):
     mock_logger = mock.MagicMock()
     mock_config_benchmark_logger.return_value = mock_logger
     with logger.benchmark_context(None):
-      tf.logging.info("start benchmarking")
+      tf.compat.v1.logging.info("start benchmarking")
     mock_logger.on_finish.assert_called_once_with(logger.RUN_STATUS_SUCCESS)

   @mock.patch("official.utils.logs.logger.config_benchmark_logger")

@@ -95,18 +95,18 @@ class BaseBenchmarkLoggerTest(tf.test.TestCase):
   def setUp(self):
     super(BaseBenchmarkLoggerTest, self).setUp()
-    self._actual_log = tf.logging.info
+    self._actual_log = tf.compat.v1.logging.info
     self.logged_message = None

     def mock_log(*args, **kwargs):
       self.logged_message = args
       self._actual_log(*args, **kwargs)

-    tf.logging.info = mock_log
+    tf.compat.v1.logging.info = mock_log

   def tearDown(self):
     super(BaseBenchmarkLoggerTest, self).tearDown()
-    tf.logging.info = self._actual_log
+    tf.compat.v1.logging.info = self._actual_log

   def test_log_metric(self):
     log = logger.BaseBenchmarkLogger()

@@ -128,16 +128,16 @@ class BenchmarkFileLoggerTest(tf.test.TestCase):
   def tearDown(self):
     super(BenchmarkFileLoggerTest, self).tearDown()
-    tf.gfile.DeleteRecursively(self.get_temp_dir())
+    tf.io.gfile.rmtree(self.get_temp_dir())
     os.environ.clear()
     os.environ.update(self.original_environ)

   def test_create_logging_dir(self):
     non_exist_temp_dir = os.path.join(self.get_temp_dir(), "unknown_dir")
-    self.assertFalse(tf.gfile.IsDirectory(non_exist_temp_dir))
+    self.assertFalse(tf.io.gfile.isdir(non_exist_temp_dir))
     logger.BenchmarkFileLogger(non_exist_temp_dir)
-    self.assertTrue(tf.gfile.IsDirectory(non_exist_temp_dir))
+    self.assertTrue(tf.io.gfile.isdir(non_exist_temp_dir))

   def test_log_metric(self):
     log_dir = tempfile.mkdtemp(dir=self.get_temp_dir())

@@ -145,8 +145,8 @@ class BenchmarkFileLoggerTest(tf.test.TestCase):
     log.log_metric("accuracy", 0.999, global_step=1e4, extras={"name": "value"})

     metric_log = os.path.join(log_dir, "metric.log")
-    self.assertTrue(tf.gfile.Exists(metric_log))
-    with tf.gfile.GFile(metric_log) as f:
+    self.assertTrue(tf.io.gfile.exists(metric_log))
+    with tf.io.gfile.GFile(metric_log) as f:
       metric = json.loads(f.readline())
       self.assertEqual(metric["name"], "accuracy")
       self.assertEqual(metric["value"], 0.999)

@@ -161,8 +161,8 @@ class BenchmarkFileLoggerTest(tf.test.TestCase):
     log.log_metric("loss", 0.02, global_step=1e4)

     metric_log = os.path.join(log_dir, "metric.log")
-    self.assertTrue(tf.gfile.Exists(metric_log))
-    with tf.gfile.GFile(metric_log) as f:
+    self.assertTrue(tf.io.gfile.exists(metric_log))
+    with tf.io.gfile.GFile(metric_log) as f:
       accuracy = json.loads(f.readline())
       self.assertEqual(accuracy["name"], "accuracy")
       self.assertEqual(accuracy["value"], 0.999)

@@ -184,7 +184,7 @@ class BenchmarkFileLoggerTest(tf.test.TestCase):
     log.log_metric("accuracy", const)

     metric_log = os.path.join(log_dir, "metric.log")
-    self.assertFalse(tf.gfile.Exists(metric_log))
+    self.assertFalse(tf.io.gfile.exists(metric_log))

   def test_log_evaluation_result(self):
     eval_result = {"loss": 0.46237424,

@@ -195,8 +195,8 @@ class BenchmarkFileLoggerTest(tf.test.TestCase):
     log.log_evaluation_result(eval_result)

     metric_log = os.path.join(log_dir, "metric.log")
-    self.assertTrue(tf.gfile.Exists(metric_log))
-    with tf.gfile.GFile(metric_log) as f:
+    self.assertTrue(tf.io.gfile.exists(metric_log))
+    with tf.io.gfile.GFile(metric_log) as f:
       accuracy = json.loads(f.readline())
       self.assertEqual(accuracy["name"], "accuracy")
       self.assertEqual(accuracy["value"], 0.9285)

@@ -216,7 +216,7 @@ class BenchmarkFileLoggerTest(tf.test.TestCase):
     log.log_evaluation_result(eval_result)

     metric_log = os.path.join(log_dir, "metric.log")
-    self.assertFalse(tf.gfile.Exists(metric_log))
+    self.assertFalse(tf.io.gfile.exists(metric_log))

   @mock.patch("official.utils.logs.logger._gather_run_info")
   def test_log_run_info(self, mock_gather_run_info):

@@ -229,8 +229,8 @@ class BenchmarkFileLoggerTest(tf.test.TestCase):
     log.log_run_info("model_name", "dataset_name", {})

     run_log = os.path.join(log_dir, "benchmark_run.log")
-    self.assertTrue(tf.gfile.Exists(run_log))
-    with tf.gfile.GFile(run_log) as f:
+    self.assertTrue(tf.io.gfile.exists(run_log))
+    with tf.io.gfile.GFile(run_log) as f:
       run_info = json.loads(f.readline())
       self.assertEqual(run_info["model_name"], "model_name")
       self.assertEqual(run_info["dataset"], "dataset_name")

@@ -240,8 +240,10 @@ class BenchmarkFileLoggerTest(tf.test.TestCase):
     run_info = {}
     logger._collect_tensorflow_info(run_info)
     self.assertNotEqual(run_info["tensorflow_version"], {})
-    self.assertEqual(run_info["tensorflow_version"]["version"], tf.VERSION)
-    self.assertEqual(run_info["tensorflow_version"]["git_hash"], tf.GIT_VERSION)
+    self.assertEqual(run_info["tensorflow_version"]["version"],
+                     tf.version.VERSION)
+    self.assertEqual(run_info["tensorflow_version"]["git_hash"],
+                     tf.version.GIT_VERSION)

   def test_collect_run_params(self):
     run_info = {}

@@ -315,7 +317,7 @@ class BenchmarkBigQueryLoggerTest(tf.test.TestCase):
   def tearDown(self):
     super(BenchmarkBigQueryLoggerTest, self).tearDown()
-    tf.gfile.DeleteRecursively(self.get_temp_dir())
+    tf.io.gfile.rmtree(self.get_temp_dir())
     os.environ.clear()
     os.environ.update(self.original_environ)
official/utils/logs/metric_hook.py

@@ -21,7 +21,7 @@ from __future__ import print_function
 import tensorflow as tf  # pylint: disable=g-bad-import-order


-class LoggingMetricHook(tf.train.LoggingTensorHook):
+class LoggingMetricHook(tf.estimator.LoggingTensorHook):
   """Hook to log benchmark metric information.

   This hook is very similar as tf.train.LoggingTensorHook, which logs given

@@ -68,7 +68,7 @@ class LoggingMetricHook(tf.train.LoggingTensorHook):
   def begin(self):
     super(LoggingMetricHook, self).begin()
-    self._global_step_tensor = tf.train.get_global_step()
+    self._global_step_tensor = tf.compat.v1.train.get_global_step()
     if self._global_step_tensor is None:
       raise RuntimeError(
           "Global step should be created to use LoggingMetricHook.")
official/utils/logs/metric_hook_test.py

@@ -39,7 +39,7 @@ class LoggingMetricHookTest(tf.test.TestCase):
   def tearDown(self):
     super(LoggingMetricHookTest, self).tearDown()
-    tf.gfile.DeleteRecursively(self.get_temp_dir())
+    tf.io.gfile.rmtree(self.get_temp_dir())

   def test_illegal_args(self):
     with self.assertRaisesRegexp(ValueError, "nvalid every_n_iter"):

@@ -55,15 +55,15 @@ class LoggingMetricHookTest(tf.test.TestCase):
       metric_hook.LoggingMetricHook(tensors=["t"], every_n_iter=5)

   def test_print_at_end_only(self):
-    with tf.Graph().as_default(), tf.Session() as sess:
-      tf.train.get_or_create_global_step()
+    with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
+      tf.compat.v1.train.get_or_create_global_step()
       t = tf.constant(42.0, name="foo")
       train_op = tf.constant(3)
       hook = metric_hook.LoggingMetricHook(
           tensors=[t.name], at_end=True, metric_logger=self._logger)
       hook.begin()
       mon_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
-      sess.run(tf.global_variables_initializer())
+      sess.run(tf.compat.v1.global_variables_initializer())

       for _ in range(3):
         mon_sess.run(train_op)

@@ -88,8 +88,8 @@ class LoggingMetricHookTest(tf.test.TestCase):
       hook.begin()

   def test_log_tensors(self):
-    with tf.Graph().as_default(), tf.Session() as sess:
-      tf.train.get_or_create_global_step()
+    with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
+      tf.compat.v1.train.get_or_create_global_step()
       t1 = tf.constant(42.0, name="foo")
       t2 = tf.constant(43.0, name="bar")
       train_op = tf.constant(3)

@@ -97,7 +97,7 @@ class LoggingMetricHookTest(tf.test.TestCase):
           tensors=[t1, t2], at_end=True, metric_logger=self._logger)
       hook.begin()
       mon_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
-      sess.run(tf.global_variables_initializer())
+      sess.run(tf.compat.v1.global_variables_initializer())

       for _ in range(3):
         mon_sess.run(train_op)

@@ -126,7 +126,7 @@ class LoggingMetricHookTest(tf.test.TestCase):
           metric_logger=self._logger)
       hook.begin()
       mon_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
-      sess.run(tf.global_variables_initializer())
+      sess.run(tf.compat.v1.global_variables_initializer())
       mon_sess.run(train_op)
       self.assertRegexpMatches(str(self._logger.logged_metric), t.name)
       for _ in range(3):

@@ -153,15 +153,15 @@ class LoggingMetricHookTest(tf.test.TestCase):
     self.assertEqual(str(self._logger.logged_metric).find(t.name), -1)

   def test_print_every_n_steps(self):
-    with tf.Graph().as_default(), tf.Session() as sess:
-      tf.train.get_or_create_global_step()
+    with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
+      tf.compat.v1.train.get_or_create_global_step()
       self._validate_print_every_n_steps(sess, at_end=False)
       # Verify proper reset.
       self._validate_print_every_n_steps(sess, at_end=False)

   def test_print_every_n_steps_and_end(self):
-    with tf.Graph().as_default(), tf.Session() as sess:
-      tf.train.get_or_create_global_step()
+    with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
+      tf.compat.v1.train.get_or_create_global_step()
       self._validate_print_every_n_steps(sess, at_end=True)
       # Verify proper reset.
       self._validate_print_every_n_steps(sess, at_end=True)

@@ -175,7 +175,7 @@ class LoggingMetricHookTest(tf.test.TestCase):
           metric_logger=self._logger)
       hook.begin()
       mon_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
-      sess.run(tf.global_variables_initializer())
+      sess.run(tf.compat.v1.global_variables_initializer())
       mon_sess.run(train_op)
       self.assertRegexpMatches(str(self._logger.logged_metric), t.name)

@@ -199,15 +199,15 @@ class LoggingMetricHookTest(tf.test.TestCase):
     self.assertEqual(str(self._logger.logged_metric).find(t.name), -1)

   def test_print_every_n_secs(self):
-    with tf.Graph().as_default(), tf.Session() as sess:
-      tf.train.get_or_create_global_step()
+    with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
+      tf.compat.v1.train.get_or_create_global_step()
       self._validate_print_every_n_secs(sess, at_end=False)
       # Verify proper reset.
       self._validate_print_every_n_secs(sess, at_end=False)

   def test_print_every_n_secs_and_end(self):
-    with tf.Graph().as_default(), tf.Session() as sess:
-      tf.train.get_or_create_global_step()
+    with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
+      tf.compat.v1.train.get_or_create_global_step()
       self._validate_print_every_n_secs(sess, at_end=True)
       # Verify proper reset.
       self._validate_print_every_n_secs(sess, at_end=True)
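A hedged, self-contained sketch of the upgraded test scaffolding these hunks share: a v1 Session, a v1 global step, and a v1 variable initializer, all reached through tf.compat.v1 (the constant and its name are illustrative):

import tensorflow as tf

with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
  tf.compat.v1.train.get_or_create_global_step()
  t = tf.constant(42.0, name="foo")
  sess.run(tf.compat.v1.global_variables_initializer())
  print(sess.run(t))   # 42.0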
official/utils/logs/mlperf_helper.py

@@ -94,7 +94,7 @@ def get_mlperf_log():
     version = pkg_resources.get_distribution("mlperf_compliance")
     version = tuple(int(i) for i in version.version.split("."))
     if version < _MIN_VERSION:
-      tf.logging.warning(
+      tf.compat.v1.logging.warning(
           "mlperf_compliance is version {}, must be >= {}".format(
               ".".join([str(i) for i in version]),
               ".".join([str(i) for i in _MIN_VERSION])))

@@ -187,6 +187,6 @@ def clear_system_caches():
 if __name__ == "__main__":
-  tf.logging.set_verbosity(tf.logging.INFO)
+  tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
   with LOGGER(True):
     ncf_print(key=TAGS.RUN_START)
official/utils/misc/model_helpers.py

@@ -48,7 +48,7 @@ def past_stop_threshold(stop_threshold, eval_metric):
                      "must be a number.")

   if eval_metric >= stop_threshold:
-    tf.logging.info(
+    tf.compat.v1.logging.info(
         "Stop threshold of {} was passed with metric value {}.".format(
             stop_threshold, eval_metric))
     return True

@@ -87,7 +87,7 @@ def generate_synthetic_data(
 def apply_clean(flags_obj):
-  if flags_obj.clean and tf.gfile.Exists(flags_obj.model_dir):
-    tf.logging.info("--clean flag set. Removing existing model dir: {}".format(
+  if flags_obj.clean and tf.io.gfile.exists(flags_obj.model_dir):
+    tf.compat.v1.logging.info("--clean flag set. Removing existing model dir: {}".format(
         flags_obj.model_dir))
-    tf.gfile.DeleteRecursively(flags_obj.model_dir)
+    tf.io.gfile.rmtree(flags_obj.model_dir)
official/utils/misc/model_helpers_test.py

@@ -69,13 +69,13 @@ class SyntheticDataTest(tf.test.TestCase):
   """Tests for generate_synthetic_data."""

   def test_generate_synethetic_data(self):
-    input_element, label_element = model_helpers.generate_synthetic_data(
-        input_shape=tf.TensorShape([5]),
-        input_value=123,
-        input_dtype=tf.float32,
-        label_shape=tf.TensorShape([]),
-        label_value=456,
-        label_dtype=tf.int32).make_one_shot_iterator().get_next()
+    input_element, label_element = tf.compat.v1.data.make_one_shot_iterator(
+        model_helpers.generate_synthetic_data(
+            input_shape=tf.TensorShape([5]),
+            input_value=123,
+            input_dtype=tf.float32,
+            label_shape=tf.TensorShape([]),
+            label_value=456,
+            label_dtype=tf.int32)).get_next()

     with self.test_session() as sess:
       for n in range(5):

@@ -89,7 +89,7 @@ class SyntheticDataTest(tf.test.TestCase):
         input_value=43.5,
         input_dtype=tf.float32)

-    element = d.make_one_shot_iterator().get_next()
+    element = tf.compat.v1.data.make_one_shot_iterator(d).get_next()
     self.assertFalse(isinstance(element, tuple))

     with self.test_session() as sess:

@@ -102,7 +102,7 @@ class SyntheticDataTest(tf.test.TestCase):
                      'b': {'c': tf.TensorShape([3]), 'd': tf.TensorShape([])}},
         input_value=1.1)

-    element = d.make_one_shot_iterator().get_next()
+    element = tf.compat.v1.data.make_one_shot_iterator(d).get_next()
     self.assertIn('a', element)
     self.assertIn('b', element)
     self.assertEquals(len(element['b']), 2)
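The dataset iterator change above follows the same shape everywhere: the method call dataset.make_one_shot_iterator() becomes the free function tf.compat.v1.data.make_one_shot_iterator(dataset). A hedged, minimal sketch (the dataset contents are illustrative):

import tensorflow as tf

dataset = tf.data.Dataset.from_tensors(123).repeat()
element = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()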
official/utils/testing/reference_data.py

@@ -170,12 +170,12 @@ class BaseTest(tf.test.TestCase):
     # Serialize graph for comparison.
     graph_bytes = graph.as_graph_def().SerializeToString()
     expected_file = os.path.join(data_dir, "expected_graph")
-    with tf.gfile.Open(expected_file, "wb") as f:
+    with tf.io.gfile.GFile(expected_file, "wb") as f:
       f.write(graph_bytes)

     with graph.as_default():
-      init = tf.global_variables_initializer()
-      saver = tf.train.Saver()
+      init = tf.compat.v1.global_variables_initializer()
+      saver = tf.compat.v1.train.Saver()

     with self.test_session(graph=graph) as sess:
       sess.run(init)

@@ -191,11 +191,11 @@ class BaseTest(tf.test.TestCase):
     if correctness_function is not None:
       results = correctness_function(*eval_results)
-      with tf.gfile.Open(os.path.join(data_dir, "results.json"), "w") as f:
+      with tf.io.gfile.GFile(os.path.join(data_dir, "results.json"), "w") as f:
         json.dump(results, f)

-    with tf.gfile.Open(os.path.join(data_dir, "tf_version.json"), "w") as f:
-      json.dump([tf.VERSION, tf.GIT_VERSION], f)
+    with tf.io.gfile.GFile(os.path.join(data_dir, "tf_version.json"), "w") as f:
+      json.dump([tf.version.VERSION, tf.version.GIT_VERSION], f)

   def _evaluate_test_case(self, name, graph, ops_to_eval, correctness_function):
     """Determine if a graph agrees with the reference data.

@@ -216,7 +216,7 @@ class BaseTest(tf.test.TestCase):
     # Serialize graph for comparison.
     graph_bytes = graph.as_graph_def().SerializeToString()
     expected_file = os.path.join(data_dir, "expected_graph")
-    with tf.gfile.Open(expected_file, "rb") as f:
+    with tf.io.gfile.GFile(expected_file, "rb") as f:
       expected_graph_bytes = f.read()
       # The serialization is non-deterministic byte-for-byte. Instead there is
       # a utility which evaluates the semantics of the two graphs to test for

@@ -228,19 +228,19 @@ class BaseTest(tf.test.TestCase):
         graph_bytes, expected_graph_bytes).decode("utf-8")

     with graph.as_default():
-      init = tf.global_variables_initializer()
-      saver = tf.train.Saver()
+      init = tf.compat.v1.global_variables_initializer()
+      saver = tf.compat.v1.train.Saver()

-    with tf.gfile.Open(os.path.join(data_dir, "tf_version.json"), "r") as f:
+    with tf.io.gfile.GFile(os.path.join(data_dir, "tf_version.json"), "r") as f:
       tf_version_reference, tf_git_version_reference = json.load(f)  # pylint: disable=unpacking-non-sequence

     tf_version_comparison = ""
-    if tf.GIT_VERSION != tf_git_version_reference:
+    if tf.version.GIT_VERSION != tf_git_version_reference:
       tf_version_comparison = (
           "Test was built using: {} (git = {})\n"
           "Local TensorFlow version: {} (git = {})"
           .format(tf_version_reference, tf_git_version_reference,
-                  tf.VERSION, tf.GIT_VERSION)
+                  tf.version.VERSION, tf.version.GIT_VERSION)
       )

     with self.test_session(graph=graph) as sess:

@@ -249,7 +249,7 @@ class BaseTest(tf.test.TestCase):
       saver.restore(sess=sess, save_path=os.path.join(
           data_dir, self.ckpt_prefix))
       if differences:
-        tf.logging.warn(
+        tf.compat.v1.logging.warn(
            "The provided graph is different than expected:\n{}\n"
            "However the weights were still able to be loaded.\n{}".format(
                differences, tf_version_comparison)

@@ -262,7 +262,7 @@ class BaseTest(tf.test.TestCase):
       eval_results = [op.eval() for op in ops_to_eval]
       if correctness_function is not None:
         results = correctness_function(*eval_results)
-        with tf.gfile.Open(os.path.join(data_dir, "results.json"), "r") as f:
+        with tf.io.gfile.GFile(os.path.join(data_dir, "results.json"), "r") as f:
          expected_results = json.load(f)
        self.assertAllClose(results, expected_results)

@@ -298,7 +298,7 @@ class BaseTest(tf.test.TestCase):
             correctness_function=correctness_function)
       except:
-        tf.logging.error("Failed unittest {}".format(name))
+        tf.compat.v1.logging.error("Failed unittest {}".format(name))
         raise
       else:
         self._construct_and_save_reference_files(
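A short, hedged sketch of the two renames that dominate this file: tf.gfile.Open becomes tf.io.gfile.GFile, and the version constants move under tf.version. The output path is illustrative:

import json
import tensorflow as tf

with tf.io.gfile.GFile("/tmp/tf_version.json", "w") as f:  # illustrative path
  json.dump([tf.version.VERSION, tf.version.GIT_VERSION], f)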
official/utils/testing/reference_data_test.py

@@ -63,12 +63,12 @@ class GoldenBaseTest(reference_data.BaseTest):
     with g.as_default():
       seed = self.name_to_seed(name)
       seed = seed + 1 if bad_seed else seed
-      tf.set_random_seed(seed)
+      tf.compat.v1.set_random_seed(seed)
       tensor_name = "wrong_tensor" if wrong_name else "input_tensor"
       tensor_shape = (1, 2) if wrong_shape else (1, 1)
-      input_tensor = tf.get_variable(
+      input_tensor = tf.compat.v1.get_variable(
           tensor_name, dtype=tf.float32,
-          initializer=tf.random_uniform(tensor_shape, maxval=1)
+          initializer=tf.random.uniform(tensor_shape, maxval=1)
       )

     def correctness_function(tensor_result):

@@ -86,13 +86,13 @@ class GoldenBaseTest(reference_data.BaseTest):
     g = tf.Graph()
     with g.as_default():
-      tf.set_random_seed(self.name_to_seed(name))
-      input_tensor = tf.get_variable(
+      tf.compat.v1.set_random_seed(self.name_to_seed(name))
+      input_tensor = tf.compat.v1.get_variable(
           "input_tensor", dtype=tf.float32,
-          initializer=tf.random_uniform((1, 2), maxval=1)
+          initializer=tf.random.uniform((1, 2), maxval=1)
       )
-      layer = tf.layers.dense(inputs=input_tensor, units=4)
-      layer = tf.layers.dense(inputs=layer, units=1)
+      layer = tf.compat.v1.layers.dense(inputs=input_tensor, units=4)
+      layer = tf.compat.v1.layers.dense(inputs=layer, units=1)
       self._save_or_test_ops(
           name=name, graph=g, ops_to_eval=[layer], test=test,
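Finally, a hedged sketch of the upgraded variable construction used in these tests: seeding, get_variable, and layers.dense all move under tf.compat.v1, while tf.random_uniform becomes tf.random.uniform. The seed and shapes below are illustrative, not the repo's values:

import tensorflow as tf

g = tf.Graph()
with g.as_default():
  tf.compat.v1.set_random_seed(1234)              # illustrative seed
  input_tensor = tf.compat.v1.get_variable(
      "input_tensor", dtype=tf.float32,
      initializer=tf.random.uniform((1, 2), maxval=1))
  layer = tf.compat.v1.layers.dense(inputs=input_tensor, units=4)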