Unverified commit 5f9f6b84 authored by Taylor Robie, committed by GitHub

Move argparsing from builtin argparse to absl (#4099)

* squash of modular absl usage commits

* delint

* address PR comments

* change hooks to comma separated list, as absl behavior for space separated lists is not as expected
parent 6ec3452c
official/mnist/mnist.py:

@@ -20,13 +20,16 @@ from __future__ import print_function

 import argparse
 import sys

+from absl import app as absl_app
+from absl import flags
 import tensorflow as tf  # pylint: disable=g-bad-import-order

 from official.mnist import dataset
-from official.utils.arg_parsers import parsers
+from official.utils.flags import core as flags_core
 from official.utils.logs import hooks_helper
 from official.utils.misc import model_helpers

 LEARNING_RATE = 1e-4
@@ -86,6 +89,16 @@ def create_model(data_format):
   ])


+def define_mnist_flags():
+  flags_core.define_base()
+  flags_core.define_image()
+  flags.adopt_module_key_flags(flags_core)
+  flags_core.set_defaults(data_dir='/tmp/mnist_data',
+                          model_dir='/tmp/mnist_model',
+                          batch_size=100,
+                          train_epochs=40)
+
+
 def model_fn(features, labels, mode, params):
   """The model_fn argument for creating an Estimator."""
   model = create_model(params['data_format'])
@@ -172,14 +185,11 @@ def validate_batch_size_for_multi_gpu(batch_size):
     raise ValueError(err)


-def main(argv):
-  parser = MNISTArgParser()
-  flags = parser.parse_args(args=argv[1:])
-
+def main(flags_obj):
   model_function = model_fn

-  if flags.multi_gpu:
-    validate_batch_size_for_multi_gpu(flags.batch_size)
+  if flags_obj.multi_gpu:
+    validate_batch_size_for_multi_gpu(flags_obj.batch_size)

     # There are two steps required if using multi-GPU: (1) wrap the model_fn,
     # and (2) wrap the optimizer. The first happens here, and (2) happens
@@ -187,16 +197,16 @@ def main(argv):
     model_function = tf.contrib.estimator.replicate_model_fn(
         model_fn, loss_reduction=tf.losses.Reduction.MEAN)

-  data_format = flags.data_format
+  data_format = flags_obj.data_format
   if data_format is None:
     data_format = ('channels_first'
                    if tf.test.is_built_with_cuda() else 'channels_last')
   mnist_classifier = tf.estimator.Estimator(
       model_fn=model_function,
-      model_dir=flags.model_dir,
+      model_dir=flags_obj.model_dir,
       params={
           'data_format': data_format,
-          'multi_gpu': flags.multi_gpu
+          'multi_gpu': flags_obj.multi_gpu
       })

   # Set up training and evaluation input functions.
@@ -206,57 +216,42 @@ def main(argv):
     # When choosing shuffle buffer sizes, larger sizes result in better
     # randomness, while smaller sizes use less memory. MNIST is a small
     # enough dataset that we can easily shuffle the full epoch.
-    ds = dataset.train(flags.data_dir)
-    ds = ds.cache().shuffle(buffer_size=50000).batch(flags.batch_size)
+    ds = dataset.train(flags_obj.data_dir)
+    ds = ds.cache().shuffle(buffer_size=50000).batch(flags_obj.batch_size)

     # Iterate through the dataset a set number (`epochs_between_evals`) of times
     # during each training session.
-    ds = ds.repeat(flags.epochs_between_evals)
+    ds = ds.repeat(flags_obj.epochs_between_evals)
     return ds

   def eval_input_fn():
-    return dataset.test(flags.data_dir).batch(
-        flags.batch_size).make_one_shot_iterator().get_next()
+    return dataset.test(flags_obj.data_dir).batch(
+        flags_obj.batch_size).make_one_shot_iterator().get_next()

   # Set up hook that outputs training logs every 100 steps.
   train_hooks = hooks_helper.get_train_hooks(
-      flags.hooks, batch_size=flags.batch_size)
+      flags_obj.hooks, batch_size=flags_obj.batch_size)

   # Train and evaluate model.
-  for _ in range(flags.train_epochs // flags.epochs_between_evals):
+  for _ in range(flags_obj.train_epochs // flags_obj.epochs_between_evals):
     mnist_classifier.train(input_fn=train_input_fn, hooks=train_hooks)
     eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
     print('\nEvaluation results:\n\t%s\n' % eval_results)

-    if model_helpers.past_stop_threshold(flags.stop_threshold,
+    if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
                                          eval_results['accuracy']):
       break

   # Export the model
-  if flags.export_dir is not None:
+  if flags_obj.export_dir is not None:
     image = tf.placeholder(tf.float32, [None, 28, 28])
     input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
         'image': image,
     })
-    mnist_classifier.export_savedmodel(flags.export_dir, input_fn)
-
-
-class MNISTArgParser(argparse.ArgumentParser):
-  """Argument parser for running MNIST model."""
-
-  def __init__(self):
-    super(MNISTArgParser, self).__init__(parents=[
-        parsers.BaseParser(),
-        parsers.ImageModelParser(),
-    ])
-
-    self.set_defaults(
-        data_dir='/tmp/mnist_data',
-        model_dir='/tmp/mnist_model',
-        batch_size=100,
-        train_epochs=40)
+    mnist_classifier.export_savedmodel(flags_obj.export_dir, input_fn)


 if __name__ == '__main__':
   tf.logging.set_verbosity(tf.logging.INFO)
-  main(argv=sys.argv)
+  define_mnist_flags()
+  absl_app.run(main)
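Note that with these flags the hook list is passed comma-separated, e.g. `--hooks LoggingTensorHook,ExamplesPerSecondHook`, per the commit message: absl's behavior for space-separated lists is not as expected.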
official/mnist/mnist_eager.py:

@@ -26,17 +26,20 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import argparse
 import os
 import sys
 import time

-import tensorflow as tf  # pylint: disable=g-bad-import-order
-import tensorflow.contrib.eager as tfe  # pylint: disable=g-bad-import-order
+# pylint: disable=g-bad-import-order
+from absl import app as absl_app
+from absl import flags
+import tensorflow as tf
+import tensorflow.contrib.eager as tfe
+# pylint: enable=g-bad-import-order

 from official.mnist import dataset as mnist_dataset
 from official.mnist import mnist
-from official.utils.arg_parsers import parsers
+from official.utils.flags import core as flags_core


 def loss(logits, labels):
@@ -95,38 +98,36 @@ def test(model, dataset):
     tf.contrib.summary.scalar('accuracy', accuracy.result())


-def main(argv):
-  parser = MNISTEagerArgParser()
-  flags = parser.parse_args(args=argv[1:])
-
+def main(flags_obj):
   tf.enable_eager_execution()

   # Automatically determine device and data_format
   (device, data_format) = ('/gpu:0', 'channels_first')
-  if flags.no_gpu or not tf.test.is_gpu_available():
+  if flags_obj.no_gpu or not tf.test.is_gpu_available():
     (device, data_format) = ('/cpu:0', 'channels_last')
   # If data_format is defined in FLAGS, overwrite automatically set value.
-  if flags.data_format is not None:
-    data_format = flags.data_format
+  if flags_obj.data_format is not None:
+    data_format = flags_obj.data_format
   print('Using device %s, and data format %s.' % (device, data_format))

   # Load the datasets
-  train_ds = mnist_dataset.train(flags.data_dir).shuffle(60000).batch(
-      flags.batch_size)
-  test_ds = mnist_dataset.test(flags.data_dir).batch(flags.batch_size)
+  train_ds = mnist_dataset.train(flags_obj.data_dir).shuffle(60000).batch(
+      flags_obj.batch_size)
+  test_ds = mnist_dataset.test(flags_obj.data_dir).batch(
+      flags_obj.batch_size)

   # Create the model and optimizer
   model = mnist.create_model(data_format)
-  optimizer = tf.train.MomentumOptimizer(flags.lr, flags.momentum)
+  optimizer = tf.train.MomentumOptimizer(flags_obj.lr, flags_obj.momentum)

   # Create file writers for writing TensorBoard summaries.
-  if flags.output_dir:
+  if flags_obj.output_dir:
     # Create directories to which summaries will be written
     # tensorboard --logdir=<output_dir>
     # can then be used to see the recorded summaries.
-    train_dir = os.path.join(flags.output_dir, 'train')
-    test_dir = os.path.join(flags.output_dir, 'eval')
-    tf.gfile.MakeDirs(flags.output_dir)
+    train_dir = os.path.join(flags_obj.output_dir, 'train')
+    test_dir = os.path.join(flags_obj.output_dir, 'eval')
+    tf.gfile.MakeDirs(flags_obj.output_dir)
   else:
     train_dir = None
     test_dir = None
@@ -136,19 +137,20 @@ def main(argv):
       test_dir, flush_millis=10000, name='test')

   # Create and restore checkpoint (if one exists on the path)
-  checkpoint_prefix = os.path.join(flags.model_dir, 'ckpt')
+  checkpoint_prefix = os.path.join(flags_obj.model_dir, 'ckpt')
   step_counter = tf.train.get_or_create_global_step()
   checkpoint = tfe.Checkpoint(
       model=model, optimizer=optimizer, step_counter=step_counter)
   # Restore variables on creation if a checkpoint exists.
-  checkpoint.restore(tf.train.latest_checkpoint(flags.model_dir))
+  checkpoint.restore(tf.train.latest_checkpoint(flags_obj.model_dir))

   # Train and evaluate for a set number of epochs.
   with tf.device(device):
-    for _ in range(flags.train_epochs):
+    for _ in range(flags_obj.train_epochs):
       start = time.time()
       with summary_writer.as_default():
-        train(model, optimizer, train_ds, step_counter, flags.log_interval)
+        train(model, optimizer, train_ds, step_counter,
+              flags_obj.log_interval)
       end = time.time()
       print('\nTrain time for epoch #%d (%d total steps): %f' %
             (checkpoint.save_counter.numpy() + 1,
@@ -159,50 +161,37 @@ def main(argv):
       checkpoint.save(checkpoint_prefix)


-class MNISTEagerArgParser(argparse.ArgumentParser):
-  """Argument parser for running MNIST model with eager training loop."""
-
-  def __init__(self):
-    super(MNISTEagerArgParser, self).__init__(parents=[
-        parsers.EagerParser(),
-        parsers.ImageModelParser()])
-
-    self.add_argument(
-        '--log_interval', '-li',
-        type=int,
-        default=10,
-        metavar='N',
-        help='[default: %(default)s] batches between logging training status')
-    self.add_argument(
-        '--output_dir', '-od',
-        type=str,
-        default=None,
-        metavar='<OD>',
-        help='[default: %(default)s] Directory to write TensorBoard summaries')
-    self.add_argument(
-        '--lr', '-lr',
-        type=float,
-        default=0.01,
-        metavar='<LR>',
-        help='[default: %(default)s] learning rate')
-    self.add_argument(
-        '--momentum', '-m',
-        type=float,
-        default=0.5,
-        metavar='<M>',
-        help='[default: %(default)s] SGD momentum')
-    self.add_argument(
-        '--no_gpu', '-nogpu',
-        action='store_true',
-        default=False,
-        help='disables GPU usage even if a GPU is available')
-    self.set_defaults(
-        data_dir='/tmp/tensorflow/mnist/input_data',
-        model_dir='/tmp/tensorflow/mnist/checkpoints/',
-        batch_size=100,
-        train_epochs=10,
-    )
+def define_mnist_eager_flags():
+  """Defined flags and defaults for MNIST in eager mode."""
+  flags_core.define_base_eager()
+  flags_core.define_image()
+  flags.adopt_module_key_flags(flags_core)
+
+  flags.DEFINE_integer(
+      name='log_interval', short_name='li', default=10,
+      help=flags_core.help_wrap('batches between logging training status'))
+
+  flags.DEFINE_string(
+      name='output_dir', short_name='od', default=None,
+      help=flags_core.help_wrap('Directory to write TensorBoard summaries'))
+
+  flags.DEFINE_float(name='learning_rate', short_name='lr', default=0.01,
+                     help=flags_core.help_wrap('Learning rate.'))
+
+  flags.DEFINE_float(name='momentum', short_name='m', default=0.5,
+                     help=flags_core.help_wrap('SGD momentum.'))
+
+  flags.DEFINE_bool(name='no_gpu', short_name='nogpu', default=False,
+                    help=flags_core.help_wrap(
+                        'disables GPU usage even if a GPU is available'))
+
+  flags_core.set_defaults(
+      data_dir='/tmp/tensorflow/mnist/input_data',
+      model_dir='/tmp/tensorflow/mnist/checkpoints/',
+      batch_size=100,
+      train_epochs=10,
+  )


 if __name__ == '__main__':
-  main(argv=sys.argv)
+  define_mnist_eager_flags()
+  absl_app.run(main=main)
official/resnet/cifar10_main.py:

@@ -21,8 +21,11 @@ from __future__ import print_function
 import os
 import sys

+from absl import app as absl_app
+from absl import flags
 import tensorflow as tf  # pylint: disable=g-bad-import-order

+from official.utils.flags import core as flags_core
 from official.resnet import resnet_model
 from official.resnet import resnet_run_loop

@@ -224,25 +227,27 @@ def cifar10_model_fn(features, labels, mode, params):
   )


-def main(argv):
-  parser = resnet_run_loop.ResnetArgParser()
-  # Set defaults that are reasonable for this model.
-  parser.set_defaults(data_dir='/tmp/cifar10_data',
-                      model_dir='/tmp/cifar10_model',
-                      resnet_size=32,
-                      train_epochs=250,
-                      epochs_between_evals=10,
-                      batch_size=128)
-
-  flags = parser.parse_args(args=argv[1:])
-
-  input_function = flags.use_synthetic_data and get_synth_input_fn() or input_fn
+def define_cifar_flags():
+  resnet_run_loop.define_resnet_flags()
+  flags.adopt_module_key_flags(resnet_run_loop)
+  flags_core.set_defaults(data_dir='/tmp/cifar10_data',
+                          model_dir='/tmp/cifar10_model',
+                          resnet_size='32',
+                          train_epochs=250,
+                          epochs_between_evals=10,
+                          batch_size=128)
+
+
+def main(flags_obj):
+  input_function = (flags_obj.use_synthetic_data and get_synth_input_fn()
+                    or input_fn)

   resnet_run_loop.resnet_main(
-      flags, cifar10_model_fn, input_function,
+      flags_obj, cifar10_model_fn, input_function,
       shape=[_HEIGHT, _WIDTH, _NUM_CHANNELS])


 if __name__ == '__main__':
   tf.logging.set_verbosity(tf.logging.INFO)
-  main(argv=sys.argv)
+  define_cifar_flags()
+  absl_app.run(main)
official/resnet/cifar10_test.py:

@@ -37,6 +37,11 @@ class BaseTest(tf.test.TestCase):
   """Tests for the Cifar10 version of Resnet.
   """

+  @classmethod
+  def setUpClass(cls):  # pylint: disable=invalid-name
+    super(BaseTest, cls).setUpClass()
+    cifar10_main.define_cifar_flags()
+
   def tearDown(self):
     super(BaseTest, self).tearDown()
     tf.gfile.DeleteRecursively(self.get_temp_dir())
...
official/resnet/imagenet_main.py:

@@ -21,8 +21,11 @@ from __future__ import print_function
 import os
 import sys

+from absl import app as absl_app
+from absl import flags
 import tensorflow as tf  # pylint: disable=g-bad-import-order

+from official.utils.flags import core as flags_core
 from official.resnet import imagenet_preprocessing
 from official.resnet import resnet_model
 from official.resnet import resnet_run_loop

@@ -303,23 +306,23 @@ def imagenet_model_fn(features, labels, mode, params):
   )


-def main(argv):
-  parser = resnet_run_loop.ResnetArgParser(
-      resnet_size_choices=[18, 34, 50, 101, 152, 200])
-
-  parser.set_defaults(
-      train_epochs=100
-  )
-
-  flags = parser.parse_args(args=argv[1:])
-
-  input_function = flags.use_synthetic_data and get_synth_input_fn() or input_fn
+def define_imagenet_flags():
+  resnet_run_loop.define_resnet_flags(
+      resnet_size_choices=['18', '34', '50', '101', '152', '200'])
+  flags.adopt_module_key_flags(resnet_run_loop)
+  flags_core.set_defaults(train_epochs=100)
+
+
+def main(flags_obj):
+  input_function = (flags_obj.use_synthetic_data and get_synth_input_fn()
+                    or input_fn)

   resnet_run_loop.resnet_main(
-      flags, imagenet_model_fn, input_function,
+      flags_obj, imagenet_model_fn, input_function,
       shape=[_DEFAULT_IMAGE_SIZE, _DEFAULT_IMAGE_SIZE, _NUM_CHANNELS])


 if __name__ == '__main__':
   tf.logging.set_verbosity(tf.logging.INFO)
-  main(argv=sys.argv)
+  define_imagenet_flags()
+  absl_app.run(main)
official/resnet/imagenet_test.py:

@@ -32,6 +32,11 @@ _LABEL_CLASSES = 1001

 class BaseTest(tf.test.TestCase):

+  @classmethod
+  def setUpClass(cls):  # pylint: disable=invalid-name
+    super(BaseTest, cls).setUpClass()
+    imagenet_main.define_imagenet_flags()
+
   def tearDown(self):
     super(BaseTest, self).tearDown()
     tf.gfile.DeleteRecursively(self.get_temp_dir())
...
official/resnet/resnet_run_loop.py:

@@ -26,16 +26,20 @@ from __future__ import print_function
 import argparse
 import os

+from absl import flags
 import tensorflow as tf  # pylint: disable=g-bad-import-order

 from official.resnet import resnet_model
-from official.utils.arg_parsers import parsers
+from official.utils.flags import core as flags_core
 from official.utils.export import export
 from official.utils.logs import hooks_helper
 from official.utils.logs import logger
 from official.utils.misc import model_helpers

+FLAGS = flags.FLAGS
+

 ################################################################################
 # Functions for input processing.
 ################################################################################
@@ -346,12 +350,12 @@ def validate_batch_size_for_multi_gpu(batch_size):
     raise ValueError(err)


-def resnet_main(flags, model_function, input_function, shape=None):
+def resnet_main(flags_obj, model_function, input_function, shape=None):
   """Shared main loop for ResNet Models.

   Args:
-    flags: FLAGS object that contains the params for running. See
-      ResnetArgParser for created flags.
+    flags_obj: An object containing parsed flags. See define_resnet_flags()
+      for details.
     model_function: the function that instantiates the Model and builds the
       ops for train/eval. This will be passed directly into the estimator.
     input_function: the function that processes the dataset and returns a

@@ -364,8 +368,8 @@ def resnet_main(flags, model_function, input_function, shape=None):
   # Using the Winograd non-fused algorithms provides a small performance boost.
   os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

-  if flags.multi_gpu:
-    validate_batch_size_for_multi_gpu(flags.batch_size)
+  if flags_obj.multi_gpu:
+    validate_batch_size_for_multi_gpu(flags_obj.batch_size)

     # There are two steps required if using multi-GPU: (1) wrap the model_fn,
     # and (2) wrap the optimizer. The first happens here, and (2) happens
@@ -379,49 +383,50 @@ def resnet_main(flags, model_function, input_function, shape=None):
   # allow_soft_placement = True, which is required for multi-GPU and not
   # harmful for other modes.
   session_config = tf.ConfigProto(
-      inter_op_parallelism_threads=flags.inter_op_parallelism_threads,
-      intra_op_parallelism_threads=flags.intra_op_parallelism_threads,
+      inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
+      intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
       allow_soft_placement=True)

   # Set up a RunConfig to save checkpoint and set session config.
   run_config = tf.estimator.RunConfig().replace(save_checkpoints_secs=1e9,
                                                 session_config=session_config)
   classifier = tf.estimator.Estimator(
-      model_fn=model_function, model_dir=flags.model_dir, config=run_config,
+      model_fn=model_function, model_dir=flags_obj.model_dir, config=run_config,
       params={
-          'resnet_size': flags.resnet_size,
-          'data_format': flags.data_format,
-          'batch_size': flags.batch_size,
-          'multi_gpu': flags.multi_gpu,
-          'version': flags.version,
-          'loss_scale': flags.loss_scale,
-          'dtype': flags.dtype
+          'resnet_size': int(flags_obj.resnet_size),
+          'data_format': flags_obj.data_format,
+          'batch_size': flags_obj.batch_size,
+          'multi_gpu': flags_obj.multi_gpu,
+          'version': int(flags_obj.version),
+          'loss_scale': flags_core.get_loss_scale(flags_obj),
+          'dtype': flags_core.get_tf_dtype(flags_obj)
       })

-  benchmark_logger = logger.config_benchmark_logger(flags.benchmark_log_dir)
+  benchmark_logger = logger.config_benchmark_logger(flags_obj.benchmark_log_dir)
   benchmark_logger.log_run_info('resnet')

   train_hooks = hooks_helper.get_train_hooks(
-      flags.hooks,
-      batch_size=flags.batch_size,
-      benchmark_log_dir=flags.benchmark_log_dir)
+      flags_obj.hooks,
+      batch_size=flags_obj.batch_size,
+      benchmark_log_dir=flags_obj.benchmark_log_dir)

   def input_fn_train():
-    return input_function(True, flags.data_dir, flags.batch_size,
-                          flags.epochs_between_evals,
-                          flags.num_parallel_calls, flags.multi_gpu)
+    return input_function(True, flags_obj.data_dir, flags_obj.batch_size,
+                          flags_obj.epochs_between_evals,
+                          flags_obj.num_parallel_calls, flags_obj.multi_gpu)

   def input_fn_eval():
-    return input_function(False, flags.data_dir, flags.batch_size,
-                          1, flags.num_parallel_calls, flags.multi_gpu)
+    return input_function(False, flags_obj.data_dir, flags_obj.batch_size,
+                          1, flags_obj.num_parallel_calls, flags_obj.multi_gpu)

-  total_training_cycle = flags.train_epochs // flags.epochs_between_evals
+  total_training_cycle = (flags_obj.train_epochs //
+                          flags_obj.epochs_between_evals)
   for cycle_index in range(total_training_cycle):
     tf.logging.info('Starting a training cycle: %d/%d',
                     cycle_index, total_training_cycle)

     classifier.train(input_fn=input_fn_train, hooks=train_hooks,
-                     max_steps=flags.max_train_steps)
+                     max_steps=flags_obj.max_train_steps)

     tf.logging.info('Starting to evaluate.')

   # flags.max_train_steps is generally associated with testing and profiling.
@@ -431,21 +436,21 @@ def resnet_main(flags, model_function, input_function, shape=None):
     # Note that eval will run for max_train_steps each loop, regardless of the
     # global_step count.
     eval_results = classifier.evaluate(input_fn=input_fn_eval,
-                                       steps=flags.max_train_steps)
+                                       steps=flags_obj.max_train_steps)

     benchmark_logger.log_evaluation_result(eval_results)

     if model_helpers.past_stop_threshold(
-        flags.stop_threshold, eval_results['accuracy']):
+        flags_obj.stop_threshold, eval_results['accuracy']):
       break

-  if flags.export_dir is not None:
-    warn_on_multi_gpu_export(flags.multi_gpu)
+  if flags_obj.export_dir is not None:
+    warn_on_multi_gpu_export(flags_obj.multi_gpu)

     # Exports a saved model for the given classifier.
     input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
-        shape, batch_size=flags.batch_size)
-    classifier.export_savedmodel(flags.export_dir, input_receiver_fn)
+        shape, batch_size=flags_obj.batch_size)
+    classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn)
 def warn_on_multi_gpu_export(multi_gpu=False):

@@ -458,35 +463,25 @@ def warn_on_multi_gpu_export(multi_gpu=False):
       'try exporting the SavedModel with multi-GPU mode turned off.')


-class ResnetArgParser(argparse.ArgumentParser):
-  """Arguments for configuring and running a Resnet Model."""
-
-  def __init__(self, resnet_size_choices=None):
-    super(ResnetArgParser, self).__init__(parents=[
-        parsers.BaseParser(),
-        parsers.PerformanceParser(),
-        parsers.ImageModelParser(),
-        parsers.BenchmarkParser(),
-    ])
-
-    self.add_argument(
-        '--version', '-v', type=int, choices=[1, 2],
-        default=resnet_model.DEFAULT_VERSION,
-        help='Version of ResNet. (1 or 2) See README.md for details.'
-    )
-    self.add_argument(
-        '--resnet_size', '-rs', type=int, default=50,
-        choices=resnet_size_choices,
-        help='[default: %(default)s] The size of the ResNet model to use.',
-        metavar='<RS>' if resnet_size_choices is None else None
-    )
-
-  def parse_args(self, args=None, namespace=None):
-    args = super(ResnetArgParser, self).parse_args(
-        args=args, namespace=namespace)
-
-    # handle coupling between dtype and loss_scale
-    parsers.parse_dtype_info(args)
-
-    return args
+def define_resnet_flags(resnet_size_choices=None):
+  """Add flags and validators for ResNet."""
+  flags_core.define_base()
+  flags_core.define_performance()
+  flags_core.define_image()
+  flags_core.define_benchmark()
+  flags.adopt_module_key_flags(flags_core)
+
+  flags.DEFINE_enum(
+      name='version', short_name='rv', default='2', enum_values=['1', '2'],
+      help=flags_core.help_wrap(
+          'Version of ResNet. (1 or 2) See README.md for details.'))
+
+  choice_kwargs = dict(
+      name='resnet_size', short_name='rs', default='50',
+      help=flags_core.help_wrap('The size of the ResNet model to use.'))
+
+  if resnet_size_choices is None:
+    flags.DEFINE_string(**choice_kwargs)
+  else:
+    flags.DEFINE_enum(enum_values=resnet_size_choices, **choice_kwargs)
official/utils/arg_parsers/parsers.py (deleted):

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Collection of parsers which are shared among the official models.
The parsers in this module are intended to be used as parents to all arg
parsers in official models. For instance, one might define a new class:
class ExampleParser(argparse.ArgumentParser):
def __init__(self):
super(ExampleParser, self).__init__(parents=[
arg_parsers.LocationParser(data_dir=True, model_dir=True),
arg_parsers.DummyParser(use_synthetic_data=True),
])
self.add_argument(
"--application_specific_arg", "-asa", type=int, default=123,
help="[default: %(default)s] This arg is application specific.",
metavar="<ASA>"
)
Notes about add_argument():
Argparse will automatically template in default values in help messages if
the "%(default)s" string appears in the message. Using the example above:
parser = ExampleParser()
parser.set_defaults(application_specific_arg=3141592)
parser.parse_args(["-h"])
When the help text is generated, it will display 3141592 to the user. (Even
though the default was 123 when the flag was created.)
The metavar variable determines how the flag will appear in help text. If
not specified, the convention is to use name.upper(). Thus rather than:
--app_specific_arg APP_SPECIFIC_ARG, -asa APP_SPECIFIC_ARG
if metavar="<ASA>" is set, the user sees:
--app_specific_arg <ASA>, -asa <ASA>
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import tensorflow as tf
# Map string to (TensorFlow dtype, default loss scale)
DTYPE_MAP = {
"fp16": (tf.float16, 128),
"fp32": (tf.float32, 1),
}
def parse_dtype_info(flags):
"""Convert dtype string to tf dtype, and set loss_scale default as needed.
Args:
flags: namespace object returned by arg parser.
Raises:
ValueError: If an invalid dtype is provided.
"""
if flags.dtype in (i[0] for i in DTYPE_MAP.values()):
return # Make function idempotent
try:
flags.dtype, default_loss_scale = DTYPE_MAP[flags.dtype]
except KeyError:
raise ValueError("Invalid dtype: {}".format(flags.dtype))
flags.loss_scale = flags.loss_scale or default_loss_scale
class BaseParser(argparse.ArgumentParser):
"""Parser to contain flags which will be nearly universal across models.
Args:
add_help: Create the "--help" flag. False if class instance is a parent.
data_dir: Create a flag for specifying the input data directory.
model_dir: Create a flag for specifying the model file directory.
train_epochs: Create a flag to specify the number of training epochs.
epochs_between_evals: Create a flag to specify the frequency of testing.
stop_threshold: Create a flag to specify a threshold accuracy or other
eval metric which should trigger the end of training.
batch_size: Create a flag to specify the batch size.
multi_gpu: Create a flag to allow the use of all available GPUs.
hooks: Create a flag to specify hooks for logging.
export_dir: Create a flag to specify where a SavedModel should be exported.
"""
def __init__(self, add_help=False, data_dir=True, model_dir=True,
train_epochs=True, epochs_between_evals=True,
stop_threshold=True, batch_size=True, multi_gpu=True,
hooks=True, export_dir=True):
super(BaseParser, self).__init__(add_help=add_help)
if data_dir:
self.add_argument(
"--data_dir", "-dd", default="/tmp",
help="[default: %(default)s] The location of the input data.",
metavar="<DD>",
)
if model_dir:
self.add_argument(
"--model_dir", "-md", default="/tmp",
help="[default: %(default)s] The location of the model checkpoint "
"files.",
metavar="<MD>",
)
if train_epochs:
self.add_argument(
"--train_epochs", "-te", type=int, default=1,
help="[default: %(default)s] The number of epochs used to train.",
metavar="<TE>"
)
if epochs_between_evals:
self.add_argument(
"--epochs_between_evals", "-ebe", type=int, default=1,
help="[default: %(default)s] The number of training epochs to run "
"between evaluations.",
metavar="<EBE>"
)
if stop_threshold:
self.add_argument(
"--stop_threshold", "-st", type=float, default=None,
help="[default: %(default)s] If passed, training will stop at "
"the earlier of train_epochs and when the evaluation metric is "
"greater than or equal to stop_threshold.",
metavar="<ST>"
)
if batch_size:
self.add_argument(
"--batch_size", "-bs", type=int, default=32,
help="[default: %(default)s] Batch size for training and evaluation.",
metavar="<BS>"
)
if multi_gpu:
self.add_argument(
"--multi_gpu", action="store_true",
help="If set, run across all available GPUs."
)
if hooks:
self.add_argument(
"--hooks", "-hk", nargs="+", default=["LoggingTensorHook"],
help="[default: %(default)s] A list of strings to specify the names "
"of train hooks. "
"Example: --hooks LoggingTensorHook ExamplesPerSecondHook. "
"Allowed hook names (case-insensitive): LoggingTensorHook, "
"ProfilerHook, ExamplesPerSecondHook, LoggingMetricHook."
"See official.utils.logs.hooks_helper for details.",
metavar="<HK>"
)
if export_dir:
self.add_argument(
"--export_dir", "-ed",
help="[default: %(default)s] If set, a SavedModel serialization of "
"the model will be exported to this directory at the end of "
"training. See the README for more details and relevant links.",
metavar="<ED>"
)
class PerformanceParser(argparse.ArgumentParser):
"""Default parser for specifying performance tuning arguments.
Args:
add_help: Create the "--help" flag. False if class instance is a parent.
num_parallel_calls: Create a flag to specify parallelism of data loading.
inter_op: Create a flag to allow specification of inter op threads.
intra_op: Create a flag to allow specification of intra op threads.
"""
def __init__(self, add_help=False, num_parallel_calls=True, inter_op=True,
intra_op=True, use_synthetic_data=True, max_train_steps=True,
dtype=True):
super(PerformanceParser, self).__init__(add_help=add_help)
if num_parallel_calls:
self.add_argument(
"--num_parallel_calls", "-npc",
type=int, default=5,
help="[default: %(default)s] The number of records that are "
"processed in parallel during input processing. This can be "
"optimized per data set but for generally homogeneous data "
"sets, should be approximately the number of available CPU "
"cores.",
metavar="<NPC>"
)
if inter_op:
self.add_argument(
"--inter_op_parallelism_threads", "-inter",
type=int, default=0,
help="[default: %(default)s Number of inter_op_parallelism_threads "
"to use for CPU. See TensorFlow config.proto for details.",
metavar="<INTER>"
)
if intra_op:
self.add_argument(
"--intra_op_parallelism_threads", "-intra",
type=int, default=0,
help="[default: %(default)s Number of intra_op_parallelism_threads "
"to use for CPU. See TensorFlow config.proto for details.",
metavar="<INTRA>"
)
if use_synthetic_data:
self.add_argument(
"--use_synthetic_data", "-synth",
action="store_true",
help="If set, use fake data (zeroes) instead of a real dataset. "
"This mode is useful for performance debugging, as it removes "
"input processing steps, but will not learn anything."
)
if max_train_steps:
self.add_argument(
"--max_train_steps", "-mts", type=int, default=None,
help="[default: %(default)s] The model will stop training if the "
"global_step reaches this value. If not set, training will run"
"until the specified number of epochs have run as usual. It is"
"generally recommended to set --train_epochs=1 when using this"
"flag.",
metavar="<MTS>"
)
if dtype:
self.add_argument(
"--dtype", "-dt",
default="fp32",
choices=list(DTYPE_MAP.keys()),
help="[default: %(default)s] {%(choices)s} The TensorFlow datatype "
"used for calculations. Variables may be cast to a higher"
"precision on a case-by-case basis for numerical stability.",
metavar="<DT>"
)
self.add_argument(
"--loss_scale", "-ls",
type=int,
help="[default: %(default)s] The amount to scale the loss by when "
"the model is run. Before gradients are computed, the loss is "
"multiplied by the loss scale, making all gradients loss_scale "
"times larger. To adjust for this, gradients are divided by the "
"loss scale before being applied to variables. This is "
"mathematically equivalent to training without a loss scale, "
"but the loss scale helps avoid some intermediate gradients "
"from underflowing to zero. If not provided the default for "
"fp16 is 128 and 1 for all other dtypes.",
)
class ImageModelParser(argparse.ArgumentParser):
"""Default parser for specification image specific behavior.
Args:
add_help: Create the "--help" flag. False if class instance is a parent.
data_format: Create a flag to specify image axis convention.
"""
def __init__(self, add_help=False, data_format=True):
super(ImageModelParser, self).__init__(add_help=add_help)
if data_format:
self.add_argument(
"--data_format", "-df",
default=None,
choices=["channels_first", "channels_last"],
help="A flag to override the data format used in the model. "
"channels_first provides a performance boost on GPU but is not "
"always compatible with CPU. If left unspecified, the data "
"format will be chosen automatically based on whether TensorFlow"
"was built for CPU or GPU.",
metavar="<CF>"
)
class BenchmarkParser(argparse.ArgumentParser):
"""Default parser for benchmark logging.
Args:
add_help: Create the "--help" flag. False if class instance is a parent.
benchmark_log_dir: Create a flag to specify location for benchmark logging.
"""
def __init__(self, add_help=False, benchmark_log_dir=True,
bigquery_uploader=True):
super(BenchmarkParser, self).__init__(add_help=add_help)
if benchmark_log_dir:
self.add_argument(
"--benchmark_log_dir", "-bld", default=None,
help="[default: %(default)s] The location of the benchmark logging.",
metavar="<BLD>"
)
if bigquery_uploader:
self.add_argument(
"--gcp_project", "-gp", default=None,
help="[default: %(default)s] The GCP project name where the benchmark"
" will be uploaded.",
metavar="<GP>"
)
self.add_argument(
"--bigquery_data_set", "-bds", default="test_benchmark",
help="[default: %(default)s] The Bigquery dataset name where the"
" benchmark will be uploaded.",
metavar="<BDS>"
)
self.add_argument(
"--bigquery_run_table", "-brt", default="benchmark_run",
help="[default: %(default)s] The Bigquery table name where the"
" benchmark run information will be uploaded.",
metavar="<BRT>"
)
self.add_argument(
"--bigquery_metric_table", "-bmt", default="benchmark_metric",
help="[default: %(default)s] The Bigquery table name where the"
" benchmark metric information will be uploaded.",
metavar="<BMT>"
)
class EagerParser(BaseParser):
"""Remove options not relevant for Eager from the BaseParser."""
def __init__(self, add_help=False, data_dir=True, model_dir=True,
train_epochs=True, batch_size=True):
super(EagerParser, self).__init__(
add_help=add_help, data_dir=data_dir, model_dir=model_dir,
train_epochs=train_epochs, epochs_between_evals=False,
stop_threshold=False, batch_size=batch_size, multi_gpu=False,
hooks=False)
official/utils/flags/README.md:

# Adding Abseil (absl) flags quickstart
## Defining a flag
absl flag definitions are similar to those in argparse, although they are defined in a global
namespace. For instance, defining a string flag looks like:
```python
from absl import flags

flags.DEFINE_string(
    name="my_flag",
    default="a_sensible_default",
    help="Here is what this flag does."
)
```
All three arguments are required, but `default` may be `None`. A common optional argument is
`short_name` for defining abbreviations. Certain `DEFINE_*` methods have other required arguments;
for instance, `DEFINE_enum` requires the `enum_values` argument to be specified.
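For example, a hypothetical enum flag (the name and values here are illustrative, not flags from the models code) might look like:

```python
from absl import flags

# enum_values is required for DEFINE_enum; parsing fails for any value
# outside this list. (Flag name and values are illustrative.)
flags.DEFINE_enum(
    name="optimizer",
    short_name="opt",
    default="sgd",
    enum_values=["sgd", "adam"],
    help="Which optimizer to use."
)
```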
## Key Flags
absl has the concept of a key flag. Any flag defined in `__main__` is considered a key flag by
default. Key flags are displayed in `--help`, others only appear in `--helpfull`. In order to
handle key flags that are defined outside the module in question, absl provides the
`flags.adopt_module_key_flags()` method. This adds the key flags of a different module to one's own
key flags. For example:
```
File: flag_source.py
---------------------------------------
from absl import flags

flags.DEFINE_string(name="my_flag", default="abc", help="a flag.")
```
```
File: my_module.py
---------------------------------------
from absl import app as absl_app
from absl import flags

import flag_source

flags.adopt_module_key_flags(flag_source)

def main(_):
  pass

absl_app.run(main, [__file__, "-h"])
```
When `my_module.py` is run, it will show the help text for `my_flag`. Because not all flags defined
in a file are equally important, `official/utils/flags/core.py` (generally imported as flags_core)
provides an abstraction for handling key flag declaration through the
`register_key_flags_in_core()` function, which allows a module to make a single
`flags.adopt_module_key_flags(flags_core)` call when using the util flag declaration functions.
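The underlying mechanism is absl's `flags.declare_key_flag()`. A minimal sketch of the pattern, assuming the decorated function returns the list of flag names it defines (as the `define_*` functions in this package do); see core.py for the actual implementation:

```python
from absl import flags

def register_key_flags_in_core(f):
  """Sketch: re-declare the flags defined by f as key flags of this module."""
  def core_fn(*args, **kwargs):
    key_flags = f(*args, **kwargs) or []
    for flag_name in key_flags:
      # Because this call happens here, in core, the flags become key
      # flags of core, so adopt_module_key_flags(core) picks them up.
      flags.declare_key_flag(flag_name)
  return core_fn
```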
## Validators
Often the constraints on a flag are complicated. absl provides the validator decorator to allow
one to mark a function as a flag validation function. Suppose we want users to provide a flag
which is a palindrome.
```python
from absl import flags

flags.DEFINE_string(name="pal_flag", short_name="pf", default="",
                    help="Give me a palindrome")

@flags.validator("pal_flag")
def _check_pal(provided_pal_flag):
  return provided_pal_flag == provided_pal_flag[::-1]
```
A validator passes when it returns True (truthy) and fails on anything else
(False, None, or a raised exception).
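The same check can also be registered without the decorator via `flags.register_validator`; in both forms a failing check raises `flags.IllegalFlagValueError` at parse time:

```python
from absl import flags

flags.DEFINE_string(name="pal_flag", short_name="pf", default="",
                    help="Give me a palindrome")

# Equivalent to the decorator form above; the message is attached to the
# IllegalFlagValueError raised when the check returns a falsy value.
flags.register_validator(
    "pal_flag",
    lambda value: value == value[::-1],
    message="--pal_flag must be a palindrome.")
```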
## Common Flags
Common flags (e.g. `batch_size`, `model_dir`) are provided by various flag definition functions
and channeled through `official.utils.flags.core`. For instance, to define common supervised
learning parameters one could use the following code:
```python
from absl import app as absl_app
from absl import flags

from official.utils.flags import core as flags_core


def define_flags():
  flags_core.define_base()
  flags.adopt_module_key_flags(flags_core)


def main(flags_obj):
  pass


if __name__ == "__main__":
  define_flags()
  absl_app.run(main)
```
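Once `define_base()` has run and the command line has been parsed, the common flags are available on `flags.FLAGS`. A small sketch using the `parse_flags` test helper from core.py (normal programs get the same parsing from `absl_app.run`):

```python
from absl import flags
from official.utils.flags import core as flags_core

flags_core.define_base()
flags_core.parse_flags([__file__, "--batch_size", "64", "--train_epochs", "3"])

print(flags.FLAGS.batch_size)    # 64
print(flags.FLAGS.train_epochs)  # 3
```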
## Testing
To test using absl, simply declare flags in the `setUpClass` method of TensorFlow's `TestCase`.
```python
from absl import flags
import tensorflow as tf

from official.utils.flags import core as flags_core


def define_flags():
  flags.DEFINE_string(name="test_flag", default="abc", help="an example flag")


class BaseTester(tf.test.TestCase):

  @classmethod
  def setUpClass(cls):
    super(BaseTester, cls).setUpClass()
    define_flags()

  def test_trivial(self):
    flags_core.parse_flags([__file__, "--test_flag", "def"])
    self.assertEqual(flags.FLAGS.test_flag, "def")
```
official/utils/flags/_base.py:

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flags which will be nearly universal across models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import flags
from official.utils.flags._conventions import help_wrap
from official.utils.logs import hooks_helper
def define_base(data_dir=True, model_dir=True, train_epochs=True,
epochs_between_evals=True, stop_threshold=True, batch_size=True,
multi_gpu=True, hooks=True, export_dir=True):
"""Register base flags.
Args:
data_dir: Create a flag for specifying the input data directory.
model_dir: Create a flag for specifying the model file directory.
train_epochs: Create a flag to specify the number of training epochs.
epochs_between_evals: Create a flag to specify the frequency of testing.
stop_threshold: Create a flag to specify a threshold accuracy or other
eval metric which should trigger the end of training.
batch_size: Create a flag to specify the batch size.
multi_gpu: Create a flag to allow the use of all available GPUs.
hooks: Create a flag to specify hooks for logging.
export_dir: Create a flag to specify where a SavedModel should be exported.
Returns:
A list of flags for core.py to mark as key flags.
"""
key_flags = []
if data_dir:
flags.DEFINE_string(
name="data_dir", short_name="dd", default="/tmp",
help=help_wrap("The location of the input data."))
key_flags.append("data_dir")
if model_dir:
flags.DEFINE_string(
name="model_dir", short_name="md", default="/tmp",
help=help_wrap("The location of the model checkpoint files."))
key_flags.append("model_dir")
if train_epochs:
flags.DEFINE_integer(
name="train_epochs", short_name="te", default=1,
help=help_wrap("The number of epochs used to train."))
key_flags.append("train_epochs")
if epochs_between_evals:
flags.DEFINE_integer(
name="epochs_between_evals", short_name="ebe", default=1,
help=help_wrap("The number of training epochs to run between "
"evaluations."))
key_flags.append("epochs_between_evals")
if stop_threshold:
flags.DEFINE_float(
name="stop_threshold", short_name="st",
default=None,
help=help_wrap("If passed, training will stop at the earlier of "
"train_epochs and when the evaluation metric is "
"greater than or equal to stop_threshold."))
if batch_size:
flags.DEFINE_integer(
name="batch_size", short_name="bs", default=32,
help=help_wrap("Batch size for training and evaluation."))
key_flags.append("batch_size")
if multi_gpu:
flags.DEFINE_bool(
name="multi_gpu", default=False,
help=help_wrap("If set, run across all available GPUs."))
key_flags.append("multi_gpu")
if hooks:
# Construct a pretty summary of hooks.
hook_list_str = (
u"\ufeff Hook:\n" + u"\n".join([u"\ufeff {}".format(key) for key
in hooks_helper.HOOKS]))
flags.DEFINE_list(
name="hooks", short_name="hk", default="LoggingTensorHook",
help=help_wrap(
u"A list of (case insensitive) strings to specify the names of "
u"training hooks.\n{}\n\ufeff Example: `--hooks ProfilerHook,"
u"ExamplesPerSecondHook`\n See official.utils.logs.hooks_helper "
u"for details.".format(hook_list_str))
)
key_flags.append("hooks")
if export_dir:
flags.DEFINE_string(
name="export_dir", short_name="ed", default=None,
help=help_wrap("If set, a SavedModel serialization of the model will "
"be exported to this directory at the end of training. "
"See the README for more details and relevant links.")
)
key_flags.append("export_dir")
return key_flags
official/utils/flags/_benchmark.py:

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flags for benchmarking models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import flags
from official.utils.flags._conventions import help_wrap
def define_benchmark(benchmark_log_dir=True, bigquery_uploader=True):
"""Register benchmarking flags.
Args:
benchmark_log_dir: Create a flag to specify location for benchmark logging.
bigquery_uploader: Create flags for uploading results to BigQuery.
Returns:
A list of flags for core.py to mark as key flags.
"""
key_flags = []
if benchmark_log_dir:
flags.DEFINE_string(
name="benchmark_log_dir", short_name="bld", default=None,
help=help_wrap("The location of the benchmark logging.")
)
if bigquery_uploader:
flags.DEFINE_string(
name="gcp_project", short_name="gp", default=None,
help=help_wrap(
"The GCP project name where the benchmark will be uploaded."))
flags.DEFINE_string(
name="bigquery_data_set", short_name="bds", default="test_benchmark",
help=help_wrap(
"The Bigquery dataset name where the benchmark will be uploaded."))
flags.DEFINE_string(
name="bigquery_run_table", short_name="brt", default="benchmark_run",
help=help_wrap("The Bigquery table name where the benchmark run "
"information will be uploaded."))
flags.DEFINE_string(
name="bigquery_metric_table", short_name="bmt",
default="benchmark_metric",
help=help_wrap("The Bigquery table name where the benchmark metric "
"information will be uploaded."))
return key_flags
official/utils/flags/_conventions.py:

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Central location for shared arparse convention definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from absl import app as absl_app
from absl import flags
# This codifies help string conventions and makes it easy to update them if
# necessary. Currently the only major effect is that help bodies start on the
# line after flags are listed. All flag definitions should wrap the text bodies
# with help wrap when calling DEFINE_*.
help_wrap = functools.partial(flags.text_wrap, length=80, indent="",
firstline_indent="\n")
# Replace None with h to also allow -h
absl_app.HelpshortFlag.SHORT_NAME = "h"
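To illustrate the convention (a minimal sketch, not part of the file above): `help_wrap` re-wraps the help body to 80 columns, and the `firstline_indent="\n"` pushes the body onto the line after the flag listing in `--help` output.

```python
from official.utils.flags._conventions import help_wrap

# The firstline_indent of "\n" means the wrapped help body should begin
# with a newline, so absl renders it below the flag name in --help output.
print(repr(help_wrap("The location of the input data.")))
```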
official/utils/flags/_misc.py:

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Misc flags."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import flags
from official.utils.flags._conventions import help_wrap
def define_image(data_format=True):
"""Register image specific flags.
Args:
data_format: Create a flag to specify image axis convention.
Returns:
A list of flags for core.py to mark as key flags.
"""
key_flags = []
if data_format:
flags.DEFINE_enum(
name="data_format", short_name="df", default=None,
enum_values=["channels_first", "channels_last"],
help=help_wrap(
"A flag to override the data format used in the model. "
"channels_first provides a performance boost on GPU but is not "
"always compatible with CPU. If left unspecified, the data format "
"will be chosen automatically based on whether TensorFlow was "
"built for CPU or GPU."))
key_flags.append("data_format")
return key_flags
official/utils/flags/_performance.py:

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Register flags for optimizing performance."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import multiprocessing
from absl import flags # pylint: disable=g-bad-import-order
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.flags._conventions import help_wrap
# Map string to (TensorFlow dtype, default loss scale)
DTYPE_MAP = {
"fp16": (tf.float16, 128),
"fp32": (tf.float32, 1),
}
def get_tf_dtype(flags_obj):
return DTYPE_MAP[flags_obj.dtype][0]
def get_loss_scale(flags_obj):
if flags_obj.loss_scale is not None:
return flags_obj.loss_scale
return DTYPE_MAP[flags_obj.dtype][1]
def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
synthetic_data=True, max_train_steps=True, dtype=True):
"""Register flags for specifying performance tuning arguments.
Args:
num_parallel_calls: Create a flag to specify parallelism of data loading.
inter_op: Create a flag to allow specification of inter op threads.
intra_op: Create a flag to allow specification of intra op threads.
synthetic_data: Create a flag to allow the use of synthetic data.
max_train_steps: Create a flag to allow specification of the maximum number
of training steps.
dtype: Create flags for specifying dtype.
Returns:
A list of flags for core.py to mark as key flags.
"""
key_flags = []
if num_parallel_calls:
flags.DEFINE_integer(
name="num_parallel_calls", short_name="npc",
default=multiprocessing.cpu_count(),
help=help_wrap("The number of records that are processed in parallel "
"during input processing. This can be optimized per "
"data set but for generally homogeneous data sets, "
"should be approximately the number of available CPU "
"cores. (default behavior)"))
if inter_op:
flags.DEFINE_integer(
name="inter_op_parallelism_threads", short_name="inter", default=0,
help=help_wrap("Number of inter_op_parallelism_threads to use for CPU. "
"See TensorFlow config.proto for details.")
)
if intra_op:
flags.DEFINE_integer(
name="intra_op_parallelism_threads", short_name="intra", default=0,
help=help_wrap("Number of intra_op_parallelism_threads to use for CPU. "
"See TensorFlow config.proto for details."))
if synthetic_data:
flags.DEFINE_bool(
name="use_synthetic_data", short_name="synth", default=False,
help=help_wrap(
"If set, use fake data (zeroes) instead of a real dataset. "
"This mode is useful for performance debugging, as it removes "
"input processing steps, but will not learn anything."))
if max_train_steps:
flags.DEFINE_integer(
name="max_train_steps", short_name="mts", default=None, help=help_wrap(
"The model will stop training if the global_step reaches this "
"value. If not set, training will run until the specified number "
"of epochs have run as usual. It is generally recommended to set "
"--train_epochs=1 when using this flag."
))
if dtype:
flags.DEFINE_enum(
name="dtype", short_name="dt", default="fp32",
enum_values=DTYPE_MAP.keys(),
help=help_wrap("The TensorFlow datatype used for calculations. "
"Variables may be cast to a higher precision on a "
"case-by-case basis for numerical stability."))
flags.DEFINE_integer(
name="loss_scale", short_name="ls", default=None,
help=help_wrap(
"The amount to scale the loss by when the model is run. Before "
"gradients are computed, the loss is multiplied by the loss scale, "
"making all gradients loss_scale times larger. To adjust for this, "
"gradients are divided by the loss scale before being applied to "
"variables. This is mathematically equivalent to training without "
"a loss scale, but the loss scale helps avoid some intermediate "
"gradients from underflowing to zero. If not provided the default "
"for fp16 is 128 and 1 for all other dtypes."))
loss_scale_val_msg = "loss_scale should be a positive integer."
@flags.validator(flag_name="loss_scale", message=loss_scale_val_msg)
def _check_loss_scale(loss_scale): # pylint: disable=unused-variable
if loss_scale is None:
return True # null case is handled in get_loss_scale()
return loss_scale > 0
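  # Illustrative arithmetic for the loss scale described above: with
  # loss_scale=128, a raw loss of 0.01 becomes 1.28 before the backward pass,
  # and every gradient is divided by 128 before being applied, so the update
  # is mathematically unchanged while small fp16 gradients avoid underflow.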
return key_flags
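# Sketch of how a model might consume the threading flags above (assumes a
# TF 1.x session config; flags_obj stands for a parsed flags.FLAGS):
#
#   session_config = tf.ConfigProto(
#       inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
#       intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
#       allow_soft_placement=True)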
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Public interface for flag definition.
See _example.py for detailed instructions on defining flags.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import sys
from absl import app as absl_app
from absl import flags
from official.utils.flags import _base
from official.utils.flags import _benchmark
from official.utils.flags import _conventions
from official.utils.flags import _misc
from official.utils.flags import _performance
def set_defaults(**kwargs):
  """Overrides the default value of each named flag."""
  for key, value in kwargs.items():
    flags.FLAGS.set_default(name=key, value=value)
def parse_flags(argv=None):
"""Reset flags and reparse. Currently only used in testing."""
flags.FLAGS.unparse_flags()
absl_app.parse_flags_with_usage(argv or sys.argv)
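# Illustrative usage of the two helpers above (assumes define_base() and
# define_performance() have already registered the flags, as in flags_test.py
# below):
#
#   set_defaults(data_format="channels_first")
#   parse_flags([__file__, "--num_parallel_calls", "4"])
#   assert flags.FLAGS.data_format == "channels_first"
#   assert flags.FLAGS.num_parallel_calls == 4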
def register_key_flags_in_core(f):
"""Defines a function in core.py, and registers its key flags.
absl uses the location of a flags.declare_key_flag() to determine the context
  in which a flag is key. Because all declarations happen in core, model main
  functions can call flags.adopt_module_key_flags() on core and correctly
  chain key flags.
Args:
f: The function to be wrapped
Returns:
The "core-defined" version of the input function.
"""
  def core_fn(*args, **kwargs):
    key_flags = f(*args, **kwargs)
    for fl in key_flags:
      flags.declare_key_flag(fl)
  return core_fn
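# Note: the wrapped core_fn returns None by design; the key flags reported by
# f are declared in this module via flags.declare_key_flag(), which is what
# lets downstream modules pick them all up with flags.adopt_module_key_flags().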
define_base = register_key_flags_in_core(_base.define_base)
# Remove options not relevant for Eager from define_base().
define_base_eager = register_key_flags_in_core(functools.partial(
_base.define_base, epochs_between_evals=False, stop_threshold=False,
multi_gpu=False, hooks=False))
define_benchmark = register_key_flags_in_core(_benchmark.define_benchmark)
define_image = register_key_flags_in_core(_misc.define_image)
define_performance = register_key_flags_in_core(_performance.define_performance)
help_wrap = _conventions.help_wrap
get_tf_dtype = _performance.get_tf_dtype
get_loss_scale = _performance.get_loss_scale
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,29 +13,28 @@
 # limitations under the License.
 # ==============================================================================
 
-import argparse
 import unittest
 
-import tensorflow as tf  # pylint: disable=g-bad-import-order
+from absl import flags
+import tensorflow as tf
 
-from official.utils.arg_parsers import parsers
+from official.utils.flags import core as flags_core  # pylint: disable=g-bad-import-order
 
 
-class TestParser(argparse.ArgumentParser):
-  """Class to test canned parser functionality."""
-
-  def __init__(self):
-    super(TestParser, self).__init__(parents=[
-        parsers.BaseParser(),
-        parsers.PerformanceParser(num_parallel_calls=True, inter_op=True,
-                                  intra_op=True, use_synthetic_data=True),
-        parsers.ImageModelParser(data_format=True),
-        parsers.BenchmarkParser(benchmark_log_dir=True, bigquery_uploader=True)
-    ])
+def define_flags():
+  flags_core.define_base()
+  flags_core.define_performance()
+  flags_core.define_image()
+  flags_core.define_benchmark()
 
 
 class BaseTester(unittest.TestCase):
 
+  @classmethod
+  def setUpClass(cls):
+    super(BaseTester, cls).setUpClass()
+    define_flags()
+
   def test_default_setting(self):
     """Test to ensure fields exist and defaults can be set.
     """
@@ -49,16 +48,15 @@ class BaseTester(unittest.TestCase):
         hooks=["LoggingTensorHook"],
         num_parallel_calls=18,
         inter_op_parallelism_threads=5,
-        intra_op_parallelism_thread=10,
+        intra_op_parallelism_threads=10,
         data_format="channels_first"
     )
 
-    parser = TestParser()
-    parser.set_defaults(**defaults)
-    namespace_vars = vars(parser.parse_args([]))
+    flags_core.set_defaults(**defaults)
+    flags_core.parse_flags()
 
     for key, value in defaults.items():
-      assert namespace_vars[key] == value
+      assert flags.FLAGS.get_flag_value(name=key, default=None) == value
 
   def test_benchmark_setting(self):
     defaults = dict(
@@ -67,40 +65,36 @@ class BaseTester(unittest.TestCase):
         gcp_project="project_abc",
     )
 
-    parser = TestParser()
-    parser.set_defaults(**defaults)
-    namespace_vars = vars(parser.parse_args([]))
+    flags_core.set_defaults(**defaults)
+    flags_core.parse_flags()
 
     for key, value in defaults.items():
-      assert namespace_vars[key] == value
+      assert flags.FLAGS.get_flag_value(name=key, default=None) == value
 
   def test_booleans(self):
     """Test to ensure boolean flags trigger as expected.
     """
-    parser = TestParser()
-    namespace = parser.parse_args(["--multi_gpu", "--use_synthetic_data"])
+    flags_core.parse_flags([__file__, "--multi_gpu", "--use_synthetic_data"])
 
-    assert namespace.multi_gpu
-    assert namespace.use_synthetic_data
+    assert flags.FLAGS.multi_gpu
+    assert flags.FLAGS.use_synthetic_data
 
   def test_parse_dtype_info(self):
-    parser = TestParser()
     for dtype_str, tf_dtype, loss_scale in [["fp16", tf.float16, 128],
                                             ["fp32", tf.float32, 1]]:
-      args = parser.parse_args(["--dtype", dtype_str])
-      parsers.parse_dtype_info(args)
+      flags_core.parse_flags([__file__, "--dtype", dtype_str])
 
-      assert args.dtype == tf_dtype
-      assert args.loss_scale == loss_scale
+      self.assertEqual(flags_core.get_tf_dtype(flags.FLAGS), tf_dtype)
+      self.assertEqual(flags_core.get_loss_scale(flags.FLAGS), loss_scale)
 
-      args = parser.parse_args(["--dtype", dtype_str, "--loss_scale", "5"])
-      parsers.parse_dtype_info(args)
+      flags_core.parse_flags(
+          [__file__, "--dtype", dtype_str, "--loss_scale", "5"])
 
-      assert args.loss_scale == 5
+      self.assertEqual(flags_core.get_loss_scale(flags.FLAGS), 5)
 
     with self.assertRaises(SystemExit):
-      parser.parse_args(["--dtype", "int8"])
+      flags_core.parse_flags([__file__, "--dtype", "int8"])
 
 
 if __name__ == "__main__":
@@ -31,9 +31,13 @@ import uuid
 
 from google.cloud import bigquery
-import tensorflow as tf  # pylint: disable=g-bad-import-order
 
+# pylint: disable=g-bad-import-order
+from absl import app as absl_app
+from absl import flags
+import tensorflow as tf
+# pylint: enable=g-bad-import-order
 
-from official.utils.arg_parsers import parsers
+from official.utils.flags import core as flags_core
 from official.utils.logs import logger
 
@@ -108,22 +112,22 @@ class BigQueryUploader(object):
         "Failed to upload benchmark info to bigquery: {}".format(errors))
 
 
-def main(argv):
-  parser = parsers.BenchmarkParser()
-  flags = parser.parse_args(args=argv[1:])
-
-  if not flags.benchmark_log_dir:
+def main(_):
+  if not flags.FLAGS.benchmark_log_dir:
     print("Usage: benchmark_uploader.py --benchmark_log_dir=/some/dir")
     sys.exit(1)
 
   uploader = BigQueryUploader(
-      flags.benchmark_log_dir,
-      gcp_project=flags.gcp_project)
+      flags.FLAGS.benchmark_log_dir,
+      gcp_project=flags.FLAGS.gcp_project)
   run_id = str(uuid.uuid4())
   uploader.upload_benchmark_run(
-      flags.bigquery_data_set, flags.bigquery_run_table, run_id)
+      flags.FLAGS.bigquery_data_set, flags.FLAGS.bigquery_run_table, run_id)
   uploader.upload_metric(
-      flags.bigquery_data_set, flags.bigquery_metric_table, run_id)
+      flags.FLAGS.bigquery_data_set, flags.FLAGS.bigquery_metric_table, run_id)
 
 
 if __name__ == "__main__":
-  main(argv=sys.argv)
+  flags_core.define_benchmark()
+  flags.adopt_module_key_flags(flags_core)
+  absl_app.run(main=main)
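# Illustrative invocation of the updated entry point (flag names come from
# define_benchmark(); the values are hypothetical):
#
#   python benchmark_uploader.py --benchmark_log_dir=/tmp/benchmark_logs \
#       --gcp_project=my-project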
@@ -19,12 +19,15 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
 import os
 import shutil
 import sys
 import tempfile
 
+from absl import flags
+
+from official.utils.flags import core as flags_core
 
 def run_synthetic(main, tmp_root, extra_flags=None, synth=True, max_train=1):
   """Performs a minimal run of a model.
@@ -55,7 +58,8 @@ def run_synthetic(main, tmp_root, extra_flags=None, synth=True, max_train=1):
     args.extend(["--max_train_steps", str(max_train)])
 
   try:
-    main(args)
+    flags_core.parse_flags(argv=args)
+    main(flags.FLAGS)
   finally:
     if os.path.exists(model_dir):
       shutil.rmtree(model_dir)
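# Illustrative call from a model's smoke test (my_model is a hypothetical
# module whose main() accepts a flags object, matching the main(flags.FLAGS)
# call above):
#
#   run_synthetic(main=my_model.main, tmp_root="/tmp")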
@@ -61,7 +61,7 @@ variable-rgx=^[a-z][a-z0-9_]*$
 # (useful for modules/projects where namespaces are manipulated during runtime
 # and thus existing member attributes cannot be deduced by static analysis. It
 # supports qualified module names, as well as Unix pattern matching.
-ignored-modules=official, official.*, tensorflow, tensorflow.*, LazyLoader, google, google.cloud.*
+ignored-modules=absl, absl.*, official, official.*, tensorflow, tensorflow.*, LazyLoader, google, google.cloud.*
 
 [CLASSES]