Unverified commit 8f63feaa, authored by Yanhui Liang and committed by GitHub
Browse files

Update resnet with logging utils (#3586)

* Update resnet with logging utils

* intermediate commit

* commit before rebase from master

* Add tests of ExamplesPerSecondHook

* Done with test

* Fix a style nit

* Fix a style nit
parent 137e750b
...@@ -37,6 +37,7 @@ import os ...@@ -37,6 +37,7 @@ import os
import tensorflow as tf import tensorflow as tf
from official.utils.arg_parsers import parsers # pylint: disable=g-bad-import-order from official.utils.arg_parsers import parsers # pylint: disable=g-bad-import-order
from official.utils.logging import hooks_helper
_BATCH_NORM_DECAY = 0.997 _BATCH_NORM_DECAY = 0.997
_BATCH_NORM_EPSILON = 1e-5 _BATCH_NORM_EPSILON = 1e-5
...@@ -748,14 +749,7 @@ def resnet_main(flags, model_function, input_function): ...@@ -748,14 +749,7 @@ def resnet_main(flags, model_function, input_function):
}) })
for _ in range(flags.train_epochs // flags.epochs_per_eval): for _ in range(flags.train_epochs // flags.epochs_per_eval):
tensors_to_log = { train_hooks = hooks_helper.get_train_hooks(flags.hooks, batch_size=flags.batch_size)
'learning_rate': 'learning_rate',
'cross_entropy': 'cross_entropy',
'train_accuracy': 'train_accuracy'
}
logging_hook = tf.train.LoggingTensorHook(
tensors=tensors_to_log, every_n_iter=100)
print('Starting a training cycle.') print('Starting a training cycle.')
...@@ -764,7 +758,7 @@ def resnet_main(flags, model_function, input_function): ...@@ -764,7 +758,7 @@ def resnet_main(flags, model_function, input_function):
flags.epochs_per_eval, flags.num_parallel_calls, flags.epochs_per_eval, flags.num_parallel_calls,
flags.multi_gpu) flags.multi_gpu)
classifier.train(input_fn=input_fn_train, hooks=[logging_hook]) classifier.train(input_fn=input_fn_train, hooks=train_hooks)
print('Starting to evaluate.') print('Starting to evaluate.')
# Evaluate the model and print results # Evaluate the model and print results
......
...@@ -21,8 +21,8 @@ parsers in official models. For instance, one might define a new class: ...@@ -21,8 +21,8 @@ parsers in official models. For instance, one might define a new class:
class ExampleParser(argparse.ArgumentParser): class ExampleParser(argparse.ArgumentParser):
def __init__(self): def __init__(self):
super(ExampleParser, self).__init__(parents=[ super(ExampleParser, self).__init__(parents=[
official.utils.arg_parsers.LocationParser(data_dir=True, model_dir=True), arg_parsers.LocationParser(data_dir=True, model_dir=True),
official.utils.arg_parsers.DummyParser(use_synthetic_data=True), arg_parsers.DummyParser(use_synthetic_data=True),
]) ])
self.add_argument( self.add_argument(
...@@ -68,11 +68,12 @@ class BaseParser(argparse.ArgumentParser): ...@@ -68,11 +68,12 @@ class BaseParser(argparse.ArgumentParser):
epochs_per_eval: Create a flag to specify the frequency of testing. epochs_per_eval: Create a flag to specify the frequency of testing.
batch_size: Create a flag to specify the batch size. batch_size: Create a flag to specify the batch size.
multi_gpu: Create a flag to allow the use of all available GPUs. multi_gpu: Create a flag to allow the use of all available GPUs.
hooks: Create a flag to specify hooks for logging.
""" """
def __init__(self, add_help=False, data_dir=True, model_dir=True, def __init__(self, add_help=False, data_dir=True, model_dir=True,
train_epochs=True, epochs_per_eval=True, batch_size=True, train_epochs=True, epochs_per_eval=True, batch_size=True,
multi_gpu=True): multi_gpu=True, hooks=True):
super(BaseParser, self).__init__(add_help=add_help) super(BaseParser, self).__init__(add_help=add_help)
if data_dir: if data_dir:
...@@ -117,6 +118,18 @@ class BaseParser(argparse.ArgumentParser): ...@@ -117,6 +118,18 @@ class BaseParser(argparse.ArgumentParser):
help="If set, run across all available GPUs." help="If set, run across all available GPUs."
) )
if hooks:
self.add_argument(
"--hooks", "-hk", nargs="+", default=["LoggingTensorHook"],
help="[default: %(default)s] A list of strings to specify the names "
"of train hooks. "
"Example: --hooks LoggingTensorHook ExamplesPerSecondHook. "
"Allowed hook names (case-insensitive): LoggingTensorHook, "
"ProfilerHook, ExamplesPerSecondHook. "
"See official.utils.logging.hooks_helper for details.",
metavar="<HK>"
)
class PerformanceParser(argparse.ArgumentParser): class PerformanceParser(argparse.ArgumentParser):
"""Default parser for specifying performance tuning arguments. """Default parser for specifying performance tuning arguments.
...@@ -190,5 +203,5 @@ class ImageModelParser(argparse.ArgumentParser): ...@@ -190,5 +203,5 @@ class ImageModelParser(argparse.ArgumentParser):
"always compatible with CPU. If left unspecified, the data " "always compatible with CPU. If left unspecified, the data "
"format will be chosen automatically based on whether TensorFlow" "format will be chosen automatically based on whether TensorFlow"
"was built for CPU or GPU.", "was built for CPU or GPU.",
metavar="<CF>", metavar="<CF>"
) )
...@@ -44,6 +44,7 @@ class BaseTester(unittest.TestCase): ...@@ -44,6 +44,7 @@ class BaseTester(unittest.TestCase):
train_epochs=534, train_epochs=534,
epochs_per_eval=15, epochs_per_eval=15,
batch_size=256, batch_size=256,
hooks=["LoggingTensorHook"],
num_parallel_calls=18, num_parallel_calls=18,
inter_op_parallelism_threads=5, inter_op_parallelism_threads=5,
intra_op_parallelism_thread=10, intra_op_parallelism_thread=10,
......
...@@ -96,7 +96,7 @@ def get_profiler_hook(save_steps=1000, **kwargs): # pylint: disable=unused-argu ...@@ -96,7 +96,7 @@ def get_profiler_hook(save_steps=1000, **kwargs): # pylint: disable=unused-argu
def get_examples_per_second_hook(every_n_steps=100, def get_examples_per_second_hook(every_n_steps=100,
batch_size=128, batch_size=128,
warm_steps=10, warm_steps=5,
**kwargs): # pylint: disable=unused-argument **kwargs): # pylint: disable=unused-argument
"""Function to get ExamplesPerSecondHook. """Function to get ExamplesPerSecondHook.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment