Unverified commit 8f63feaa, authored by Yanhui Liang and committed by GitHub
Browse files

Update resnet with logging utils (#3586)

* Update resnet with logging utils

* intermediate commit

* commit before rebase from master

* Add tests of ExamplesPerSecondHook

* Done with test

* Fix a style nit

* Fix a style nit
parent 137e750b
......@@ -37,6 +37,7 @@ import os
import tensorflow as tf
from official.utils.arg_parsers import parsers # pylint: disable=g-bad-import-order
from official.utils.logging import hooks_helper
_BATCH_NORM_DECAY = 0.997
_BATCH_NORM_EPSILON = 1e-5
......@@ -748,14 +749,7 @@ def resnet_main(flags, model_function, input_function):
})
for _ in range(flags.train_epochs // flags.epochs_per_eval):
tensors_to_log = {
'learning_rate': 'learning_rate',
'cross_entropy': 'cross_entropy',
'train_accuracy': 'train_accuracy'
}
logging_hook = tf.train.LoggingTensorHook(
tensors=tensors_to_log, every_n_iter=100)
train_hooks = hooks_helper.get_train_hooks(flags.hooks, batch_size=flags.batch_size)
print('Starting a training cycle.')
......@@ -764,7 +758,7 @@ def resnet_main(flags, model_function, input_function):
flags.epochs_per_eval, flags.num_parallel_calls,
flags.multi_gpu)
classifier.train(input_fn=input_fn_train, hooks=[logging_hook])
classifier.train(input_fn=input_fn_train, hooks=train_hooks)
print('Starting to evaluate.')
# Evaluate the model and print results
......
......@@ -21,8 +21,8 @@ parsers in official models. For instance, one might define a new class:
class ExampleParser(argparse.ArgumentParser):
def __init__(self):
super(ExampleParser, self).__init__(parents=[
official.utils.arg_parsers.LocationParser(data_dir=True, model_dir=True),
official.utils.arg_parsers.DummyParser(use_synthetic_data=True),
arg_parsers.LocationParser(data_dir=True, model_dir=True),
arg_parsers.DummyParser(use_synthetic_data=True),
])
self.add_argument(
......@@ -68,11 +68,12 @@ class BaseParser(argparse.ArgumentParser):
epochs_per_eval: Create a flag to specify the frequency of testing.
batch_size: Create a flag to specify the batch size.
multi_gpu: Create a flag to allow the use of all available GPUs.
hooks: Create a flag to specify hooks for logging.
"""
def __init__(self, add_help=False, data_dir=True, model_dir=True,
train_epochs=True, epochs_per_eval=True, batch_size=True,
multi_gpu=True):
multi_gpu=True, hooks=True):
super(BaseParser, self).__init__(add_help=add_help)
if data_dir:
......@@ -117,6 +118,18 @@ class BaseParser(argparse.ArgumentParser):
help="If set, run across all available GPUs."
)
if hooks:
self.add_argument(
"--hooks", "-hk", nargs="+", default=["LoggingTensorHook"],
help="[default: %(default)s] A list of strings to specify the names "
"of train hooks. "
"Example: --hooks LoggingTensorHook ExamplesPerSecondHook. "
"Allowed hook names (case-insensitive): LoggingTensorHook, "
"ProfilerHook, ExamplesPerSecondHook. "
"See official.utils.logging.hooks_helper for details.",
metavar="<HK>"
)
class PerformanceParser(argparse.ArgumentParser):
"""Default parser for specifying performance tuning arguments.
......@@ -190,5 +203,5 @@ class ImageModelParser(argparse.ArgumentParser):
"always compatible with CPU. If left unspecified, the data "
"format will be chosen automatically based on whether TensorFlow"
"was built for CPU or GPU.",
metavar="<CF>",
metavar="<CF>"
)
......@@ -44,6 +44,7 @@ class BaseTester(unittest.TestCase):
train_epochs=534,
epochs_per_eval=15,
batch_size=256,
hooks=["LoggingTensorHook"],
num_parallel_calls=18,
inter_op_parallelism_threads=5,
intra_op_parallelism_thread=10,
......
......@@ -96,7 +96,7 @@ def get_profiler_hook(save_steps=1000, **kwargs): # pylint: disable=unused-argu
def get_examples_per_second_hook(every_n_steps=100,
batch_size=128,
warm_steps=10,
warm_steps=5,
**kwargs): # pylint: disable=unused-argument
"""Function to get ExamplesPerSecondHook.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment