Unverified Commit 1d76d3ae authored by Qianli Scott Zhu's avatar Qianli Scott Zhu Committed by GitHub
Browse files

Add benchmark logging for wide_deep. (#4220)

* Add benchmark logging for wide_deep.

* Fix lint error.
parent d0b6a34b
......@@ -26,6 +26,7 @@ import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.flags import core as flags_core
from official.utils.logs import hooks_helper
from official.utils.logs import logger
from official.utils.misc import model_helpers
......@@ -51,6 +52,7 @@ LOSS_PREFIX = {'wide': 'linear/', 'deep': 'dnn/'}
def define_wide_deep_flags():
"""Add supervised learning flags, as well as wide-deep model type."""
flags_core.define_base()
flags_core.define_benchmark()
flags.adopt_module_key_flags(flags_core)
......@@ -237,6 +239,15 @@ def run_wide_deep(flags_obj):
def eval_input_fn():
return input_fn(test_file, 1, False, flags_obj.batch_size)
run_params = {
'batch_size': flags_obj.batch_size,
'train_epochs': flags_obj.train_epochs,
'model_type': flags_obj.model_type,
}
benchmark_logger = logger.config_benchmark_logger(flags_obj.benchmark_log_dir)
benchmark_logger.log_run_info('wide_deep', 'Census Income', run_params)
loss_prefix = LOSS_PREFIX.get(flags_obj.model_type, '')
train_hooks = hooks_helper.get_train_hooks(
flags_obj.hooks, batch_size=flags_obj.batch_size,
......@@ -249,11 +260,15 @@ def run_wide_deep(flags_obj):
results = model.evaluate(input_fn=eval_input_fn)
# Display evaluation metrics
print('Results at epoch', (n + 1) * flags_obj.epochs_between_evals)
print('-' * 60)
tf.logging.info('Results at epoch %d / %d',
(n + 1) * flags_obj.epochs_between_evals,
flags_obj.train_epochs)
tf.logging.info('-' * 60)
for key in sorted(results):
print('%s: %s' % (key, results[key]))
tf.logging.info('%s: %s' % (key, results[key]))
benchmark_logger.log_evaluation_result(results)
if model_helpers.past_stop_threshold(
flags_obj.stop_threshold, results['accuracy']):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment