Commit 1d76d3ae (unverified)
Authored May 09, 2018 by Qianli Scott Zhu; committed via GitHub on May 09, 2018
Parent: d0b6a34b

Add benchmark logging for wide_deep. (#4220)

* Add benchmark logging for wide_deep.
* Fix lint error.

Showing 1 changed file with 18 additions and 3 deletions.

official/wide_deep/wide_deep.py @ 1d76d3ae (+18, -3)
@@ -26,6 +26,7 @@ import tensorflow as tf  # pylint: disable=g-bad-import-order
 from official.utils.flags import core as flags_core
 from official.utils.logs import hooks_helper
+from official.utils.logs import logger
 from official.utils.misc import model_helpers
@@ -51,6 +52,7 @@ LOSS_PREFIX = {'wide': 'linear/', 'deep': 'dnn/'}
 def define_wide_deep_flags():
   """Add supervised learning flags, as well as wide-deep model type."""
   flags_core.define_base()
+  flags_core.define_benchmark()
   flags.adopt_module_key_flags(flags_core)
@@ -237,6 +239,15 @@ def run_wide_deep(flags_obj):
   def eval_input_fn():
     return input_fn(test_file, 1, False, flags_obj.batch_size)

+  run_params = {
+      'batch_size': flags_obj.batch_size,
+      'train_epochs': flags_obj.train_epochs,
+      'model_type': flags_obj.model_type,
+  }
+
+  benchmark_logger = logger.config_benchmark_logger(flags_obj.benchmark_log_dir)
+  benchmark_logger.log_run_info('wide_deep', 'Census Income', run_params)
+
   loss_prefix = LOSS_PREFIX.get(flags_obj.model_type, '')
   train_hooks = hooks_helper.get_train_hooks(
       flags_obj.hooks, batch_size=flags_obj.batch_size,
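Taken together, the additions in this hunk configure a benchmark logger once per run and record the run metadata before training starts. Below is a minimal sketch of the same pattern, assuming the TensorFlow models repository is on PYTHONPATH so that official.utils.logs.logger resolves, and that config_benchmark_logger and log_run_info accept the arguments shown in the diff above; setup_benchmark_logging is a hypothetical helper name used only for illustration.

# A minimal sketch, not the script itself: assumes official.utils.logs.logger
# is importable and behaves as shown in the diff above.
from official.utils.logs import logger


def setup_benchmark_logging(flags_obj):
  """Configure the benchmark logger and record run metadata once per run."""
  run_params = {
      'batch_size': flags_obj.batch_size,
      'train_epochs': flags_obj.train_epochs,
      'model_type': flags_obj.model_type,
  }
  benchmark_logger = logger.config_benchmark_logger(flags_obj.benchmark_log_dir)
  # Model name, dataset name, then the free-form dict of run parameters.
  benchmark_logger.log_run_info('wide_deep', 'Census Income', run_params)
  return benchmark_logger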
@@ -249,11 +260,15 @@ def run_wide_deep(flags_obj):
     results = model.evaluate(input_fn=eval_input_fn)

     # Display evaluation metrics
-    print('Results at epoch', (n + 1) * flags_obj.epochs_between_evals)
-    print('-' * 60)
+    tf.logging.info('Results at epoch %d / %d',
+                    (n + 1) * flags_obj.epochs_between_evals,
+                    flags_obj.train_epochs)
+    tf.logging.info('-' * 60)

     for key in sorted(results):
-      print('%s: %s' % (key, results[key]))
+      tf.logging.info('%s: %s' % (key, results[key]))
+
+    benchmark_logger.log_evaluation_result(results)

     if model_helpers.past_stop_threshold(
         flags_obj.stop_threshold, results['accuracy']):
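For reference, a condensed sketch of the evaluation step after this change, under these assumptions: TF 1.x (so tf.logging exists), model is a tf.estimator.Estimator, benchmark_logger is the object returned by logger.config_benchmark_logger, and log_eval_results is a hypothetical helper name used only for illustration.

# A minimal sketch of the per-evaluation logging pattern introduced here.
import tensorflow as tf  # TF 1.x API, matching the script


def log_eval_results(n, flags_obj, model, eval_input_fn, benchmark_logger):
  results = model.evaluate(input_fn=eval_input_fn)

  # Structured logging replaces the plain print() calls removed in the diff.
  tf.logging.info('Results at epoch %d / %d',
                  (n + 1) * flags_obj.epochs_between_evals,
                  flags_obj.train_epochs)
  tf.logging.info('-' * 60)
  for key in sorted(results):
    tf.logging.info('%s: %s' % (key, results[key]))

  # Persist the evaluation metrics so benchmark tooling can consume them.
  benchmark_logger.log_evaluation_result(results)
  return results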