Unverified commit adfd5a3a authored by Katherine Wu, committed by GitHub

Use util functions hooks_helper and parser in mnist and wide_deep, and rename epochs_between_evals (from epochs_per_eval) (#3650)
parent 875fcb3b
@@ -11,7 +11,10 @@ APIs.

 ## Setup

-To begin, you'll simply need the latest version of TensorFlow installed.
+To begin, you'll simply need the latest version of TensorFlow installed,
+and make sure to run the command to export the `/models` folder to the
+python path: https://github.com/tensorflow/models/tree/master/official#running-the-models

 Then to train the model, run the following:

 ```
......
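(Aside: the README above points at the repo instructions for putting `/models` on the Python path. As an illustration only, not part of the README, the same effect can be had from inside Python, assuming the repo was cloned to `~/models`:)

```python
# Illustrative only: make the `official.*` packages importable when the
# repo lives at ~/models (adjust the path for your checkout).
import os
import sys

sys.path.insert(0, os.path.expanduser('~/models'))

from official.mnist import dataset  # should now resolve
```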
@@ -18,12 +18,15 @@ from __future__ import division
 from __future__ import print_function

 import argparse
+import os
 import sys

 import tensorflow as tf

 from official.mnist import dataset
+from official.utils.arg_parsers import parsers
+from official.utils.logging import hooks_helper
+
+LEARNING_RATE = 1e-4

 class Model(tf.keras.Model):
   """Model to recognize digits in the MNIST dataset.
@@ -104,7 +107,7 @@ def model_fn(features, labels, mode, params):
         'classify': tf.estimator.export.PredictOutput(predictions)
     })
   if mode == tf.estimator.ModeKeys.TRAIN:
-    optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
+    optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)

     # If we are running multi-GPU, we need to wrap the optimizer.
     if params.get('multi_gpu'):
@@ -114,10 +117,15 @@ def model_fn(features, labels, mode, params):
     loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
     accuracy = tf.metrics.accuracy(
         labels=labels, predictions=tf.argmax(logits, axis=1))
-    # Name the accuracy tensor 'train_accuracy' to demonstrate the
-    # LoggingTensorHook.
+
+    # Name tensors to be logged with LoggingTensorHook.
+    tf.identity(LEARNING_RATE, 'learning_rate')
+    tf.identity(loss, 'cross_entropy')
     tf.identity(accuracy[1], name='train_accuracy')
+
+    # Save accuracy scalar to Tensorboard output.
     tf.summary.scalar('train_accuracy', accuracy[1])
+
     return tf.estimator.EstimatorSpec(
         mode=tf.estimator.ModeKeys.TRAIN,
         loss=loss,
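A note for readers skimming the hunk above: `LoggingTensorHook` finds tensors by graph name, which is why the new code routes the loss and learning rate through `tf.identity` before training. A self-contained sketch of that mechanism (toy graph, TF 1.x API; nothing here is taken from the commit itself):

```python
import tensorflow as tf

tf.logging.set_verbosity(tf.logging.INFO)

# tf.identity gives the tensor a stable, human-readable graph name
# ('doubled:0') that a LoggingTensorHook can later look up.
x = tf.constant(3.0)
doubled = tf.identity(x * 2.0, name='doubled')

hook = tf.train.LoggingTensorHook(
    tensors={'doubled': 'doubled'},  # label -> tensor name in the graph
    every_n_iter=1)

with tf.train.MonitoredSession(hooks=[hook]) as sess:
  sess.run(doubled)  # the hook logs "doubled = 6.0" at INFO level
```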
@@ -185,30 +193,32 @@ def main(unused_argv):
       'multi_gpu': FLAGS.multi_gpu
   })

-  # Train the model
+  # Set up training and evaluation input functions.
   def train_input_fn():
     # When choosing shuffle buffer sizes, larger sizes result in better
     # randomness, while smaller sizes use less memory. MNIST is a small
     # enough dataset that we can easily shuffle the full epoch.
     ds = dataset.train(FLAGS.data_dir)
-    ds = ds.cache().shuffle(buffer_size=50000).batch(FLAGS.batch_size).repeat(
-        FLAGS.train_epochs)
+    ds = ds.cache().shuffle(buffer_size=50000).batch(FLAGS.batch_size)
+
+    # Iterate through the dataset a set number (`epochs_between_evals`) of times
+    # during each training session.
+    ds = ds.repeat(FLAGS.epochs_between_evals)
     return ds

-  # Set up training hook that logs the training accuracy every 100 steps.
-  tensors_to_log = {'train_accuracy': 'train_accuracy'}
-  logging_hook = tf.train.LoggingTensorHook(
-      tensors=tensors_to_log, every_n_iter=100)
-  mnist_classifier.train(input_fn=train_input_fn, hooks=[logging_hook])
-
-  # Evaluate the model and print results
   def eval_input_fn():
     return dataset.test(FLAGS.data_dir).batch(
         FLAGS.batch_size).make_one_shot_iterator().get_next()

-  eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
-  print()
-  print('Evaluation results:\n\t%s' % eval_results)
+  # Set up hook that outputs training logs every 100 steps.
+  train_hooks = hooks_helper.get_train_hooks(
+      FLAGS.hooks, batch_size=FLAGS.batch_size)
+
+  # Train and evaluate model.
+  for n in range(FLAGS.train_epochs // FLAGS.epochs_between_evals):
+    mnist_classifier.train(input_fn=train_input_fn, hooks=train_hooks)
+    eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
+    print('\nEvaluation results:\n\t%s\n' % eval_results)

   # Export the model
   if FLAGS.export_dir is not None:
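One thing the loop above relies on: each `train()` call runs until `train_input_fn`'s dataset is exhausted, so repeating the dataset `epochs_between_evals` times is what bounds a single training cycle. A toy sanity check of that behavior, independent of the Estimator (illustrative values, TF 1.x graph mode):

```python
import tensorflow as tf

EPOCHS_BETWEEN_EVALS = 2  # illustrative value, not from the diff

def toy_input_fn():
  # One "epoch" here is 10 examples in batches of 5, i.e. 2 batches.
  # Repeating makes a single pass over the returned dataset span
  # EPOCHS_BETWEEN_EVALS epochs, which is what bounds each train() call.
  ds = tf.data.Dataset.range(10).batch(5)
  return ds.repeat(EPOCHS_BETWEEN_EVALS)

batch = toy_input_fn().make_one_shot_iterator().get_next()
count = 0
with tf.Session() as sess:
  try:
    while True:
      sess.run(batch)
      count += 1
  except tf.errors.OutOfRangeError:
    pass
print('batches consumed per training cycle:', count)  # 2 batches * 2 epochs = 4
```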
@@ -220,51 +230,28 @@ def main(unused_argv):
 class MNISTArgParser(argparse.ArgumentParser):
+  """Argument parser for running MNIST model."""

   def __init__(self):
-    super(MNISTArgParser, self).__init__()
+    super(MNISTArgParser, self).__init__(parents=[
+        parsers.BaseParser(),
+        parsers.ImageModelParser()])

-    self.add_argument(
-        '--multi_gpu', action='store_true',
-        help='If set, run across all available GPUs.')
-    self.add_argument(
-        '--batch_size',
-        type=int,
-        default=100,
-        help='Number of images to process in a batch')
-    self.add_argument(
-        '--data_dir',
-        type=str,
-        default='/tmp/mnist_data',
-        help='Path to directory containing the MNIST dataset')
-    self.add_argument(
-        '--model_dir',
-        type=str,
-        default='/tmp/mnist_model',
-        help='The directory where the model will be stored.')
-    self.add_argument(
-        '--train_epochs',
-        type=int,
-        default=40,
-        help='Number of epochs to train.')
-    self.add_argument(
-        '--data_format',
-        type=str,
-        default=None,
-        choices=['channels_first', 'channels_last'],
-        help='A flag to override the data format used in the model. '
-        'channels_first provides a performance boost on GPU but is not always '
-        'compatible with CPU. If left unspecified, the data format will be '
-        'chosen automatically based on whether TensorFlow was built for CPU or '
-        'GPU.')
     self.add_argument(
         '--export_dir',
         type=str,
-        help='The directory where the exported SavedModel will be stored.')
+        help='[default: %(default)s] If set, a SavedModel serialization of the '
+             'model will be exported to this directory at the end of training. '
+             'See the README for more details and relevant links.')
+
+    self.set_defaults(
+        data_dir='/tmp/mnist_data',
+        model_dir='/tmp/mnist_model',
+        batch_size=100,
+        train_epochs=40)


 if __name__ == '__main__':
-  parser = MNISTArgParser()
   tf.logging.set_verbosity(tf.logging.INFO)
+  parser = MNISTArgParser()
   FLAGS, unparsed = parser.parse_known_args()
   tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
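For anyone unfamiliar with the `parents=` pattern the new parser relies on: argparse lets a child parser inherit flags from reusable base parsers and then override their defaults, which is exactly what the `set_defaults` call above does. A standalone sketch with made-up flags (not the real `parsers` module):

```python
import argparse

# A reusable base parser; parents must be built with add_help=False so the
# child's -h/--help does not conflict.
base = argparse.ArgumentParser(add_help=False)
base.add_argument('--batch_size', type=int, default=128)

# The child inherits --batch_size, adds its own flag, and overrides the
# inherited default, mirroring MNISTArgParser above.
parser = argparse.ArgumentParser(parents=[base])
parser.add_argument('--export_dir', type=str)
parser.set_defaults(batch_size=100)

flags, _ = parser.parse_known_args(['--export_dir', '/tmp/x'])
print(flags.batch_size, flags.export_dir)  # 100 /tmp/x
```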
@@ -33,8 +33,10 @@ import time
 import tensorflow as tf
 import tensorflow.contrib.eager as tfe

 from official.mnist import mnist
 from official.mnist import dataset
+from official.utils.arg_parsers import parsers

 FLAGS = None
@@ -98,9 +100,13 @@ def test(model, dataset):
 def main(_):
   tfe.enable_eager_execution()

+  # Automatically determine device and data_format
   (device, data_format) = ('/gpu:0', 'channels_first')
   if FLAGS.no_gpu or tfe.num_gpus() <= 0:
     (device, data_format) = ('/cpu:0', 'channels_last')
+  # If data_format is defined in FLAGS, overwrite automatically set value.
+  if FLAGS.data_format is not None:
+    data_format = FLAGS.data_format
   print('Using device %s, and data format %s.' % (device, data_format))

   # Load the datasets
@@ -112,6 +118,7 @@ def main(_):
   model = mnist.Model(data_format)
   optimizer = tf.train.MomentumOptimizer(FLAGS.lr, FLAGS.momentum)

+  # Create file writers for writing TensorBoard summaries.
   if FLAGS.output_dir:
     # Create directories to which summaries will be written
     # tensorboard --logdir=<output_dir>
@@ -126,15 +133,18 @@ def main(_):
       train_dir, flush_millis=10000)
   test_summary_writer = tf.contrib.summary.create_file_writer(
       test_dir, flush_millis=10000, name='test')
-  checkpoint_prefix = os.path.join(FLAGS.checkpoint_dir, 'ckpt')

+  # Create and restore checkpoint (if one exists on the path)
+  checkpoint_prefix = os.path.join(FLAGS.model_dir, 'ckpt')
   step_counter = tf.train.get_or_create_global_step()
   checkpoint = tfe.Checkpoint(
       model=model, optimizer=optimizer, step_counter=step_counter)
   # Restore variables on creation if a checkpoint exists.
-  checkpoint.restore(tf.train.latest_checkpoint(FLAGS.checkpoint_dir))
+  checkpoint.restore(tf.train.latest_checkpoint(FLAGS.model_dir))

-  # Train and evaluate for 10 epochs.
+  # Train and evaluate for a set number of epochs.
   with tf.device(device):
-    for _ in range(10):
+    for _ in range(FLAGS.train_epochs):
       start = time.time()
       with summary_writer.as_default():
         train(model, optimizer, train_ds, step_counter, FLAGS.log_interval)
@@ -148,54 +158,52 @@ def main(_):
       checkpoint.save(checkpoint_prefix)


-if __name__ == '__main__':
-  parser = argparse.ArgumentParser()
-  parser.add_argument(
-      '--data_dir',
-      type=str,
-      default='/tmp/tensorflow/mnist/input_data',
-      help='Directory for storing input data')
-  parser.add_argument(
-      '--batch_size',
-      type=int,
-      default=100,
-      metavar='N',
-      help='input batch size for training (default: 100)')
-  parser.add_argument(
-      '--log_interval',
-      type=int,
-      default=10,
-      metavar='N',
-      help='how many batches to wait before logging training status')
-  parser.add_argument(
-      '--output_dir',
-      type=str,
-      default=None,
-      metavar='N',
-      help='Directory to write TensorBoard summaries')
-  parser.add_argument(
-      '--checkpoint_dir',
-      type=str,
-      default='/tmp/tensorflow/mnist/checkpoints/',
-      metavar='N',
-      help='Directory to save checkpoints in (once per epoch)')
-  parser.add_argument(
-      '--lr',
-      type=float,
-      default=0.01,
-      metavar='LR',
-      help='learning rate (default: 0.01)')
-  parser.add_argument(
-      '--momentum',
-      type=float,
-      default=0.5,
-      metavar='M',
-      help='SGD momentum (default: 0.5)')
-  parser.add_argument(
-      '--no_gpu',
-      action='store_true',
-      default=False,
-      help='disables GPU usage even if a GPU is available')
+class MNISTEagerArgParser(argparse.ArgumentParser):
+  """Argument parser for running MNIST model with eager training loop."""
+
+  def __init__(self):
+    super(MNISTEagerArgParser, self).__init__(parents=[
+        parsers.BaseParser(epochs_between_evals=False, multi_gpu=False,
+                           hooks=False),
+        parsers.ImageModelParser()])
+
+    self.add_argument(
+        '--log_interval', '-li',
+        type=int,
+        default=10,
+        metavar='N',
+        help='[default: %(default)s] batches between logging training status')
+    self.add_argument(
+        '--output_dir', '-od',
+        type=str,
+        default=None,
+        metavar='<OD>',
+        help='[default: %(default)s] Directory to write TensorBoard summaries')
+    self.add_argument(
+        '--lr', '-lr',
+        type=float,
+        default=0.01,
+        metavar='<LR>',
+        help='[default: %(default)s] learning rate')
+    self.add_argument(
+        '--momentum', '-m',
+        type=float,
+        default=0.5,
+        metavar='<M>',
+        help='[default: %(default)s] SGD momentum')
+    self.add_argument(
+        '--no_gpu', '-nogpu',
+        action='store_true',
+        default=False,
+        help='disables GPU usage even if a GPU is available')
+
+    self.set_defaults(
+        data_dir='/tmp/tensorflow/mnist/input_data',
+        model_dir='/tmp/tensorflow/mnist/checkpoints/',
+        batch_size=100,
+        train_epochs=10,
+    )
+
+
+if __name__ == '__main__':
+  parser = MNISTEagerArgParser()
   FLAGS, unparsed = parser.parse_known_args()
   tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
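The eager changes above fold `--checkpoint_dir` into the shared `--model_dir` flag while keeping the save/restore idiom unchanged. A minimal sketch of that idiom (TF 1.x contrib.eager API, toy variable and path; not code from this commit):

```python
import os
import tensorflow as tf
import tensorflow.contrib.eager as tfe

tfe.enable_eager_execution()

model_dir = '/tmp/eager_ckpt_demo'  # stand-in for FLAGS.model_dir
step_counter = tf.train.get_or_create_global_step()
v = tfe.Variable(1.0, name='v')

checkpoint = tfe.Checkpoint(v=v, step_counter=step_counter)
# restore() is a harmless no-op when no checkpoint exists yet, so the same
# script both starts fresh on the first run and resumes on later runs.
checkpoint.restore(tf.train.latest_checkpoint(model_dir))

v.assign_add(1.0)
checkpoint.save(os.path.join(model_dir, 'ckpt'))
print('v =', v.numpy())  # grows by 1 on every run
```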
@@ -216,7 +216,7 @@ def main(argv):
       model_dir='/tmp/cifar10_model',
       resnet_size=32,
       train_epochs=250,
-      epochs_per_eval=10,
+      epochs_between_evals=10,
       batch_size=128)
   flags = parser.parse_args(args=argv[1:])
......
@@ -339,15 +339,16 @@ def resnet_main(flags, model_function, input_function):
       'version': flags.version,
   })

-  for _ in range(flags.train_epochs // flags.epochs_per_eval):
-    train_hooks = hooks_helper.get_train_hooks(flags.hooks, batch_size=flags.batch_size)
+  for _ in range(flags.train_epochs // flags.epochs_between_evals):
+    train_hooks = hooks_helper.get_train_hooks(flags.hooks,
+                                               batch_size=flags.batch_size)

     print('Starting a training cycle.')

     def input_fn_train():
       return input_function(True, flags.data_dir, flags.batch_size,
-                            flags.epochs_per_eval, flags.num_parallel_calls,
-                            flags.multi_gpu)
+                            flags.epochs_between_evals,
+                            flags.num_parallel_calls, flags.multi_gpu)

     classifier.train(input_fn=input_fn_train, hooks=train_hooks,
                      max_steps=flags.max_train_steps)
......
@@ -70,14 +70,14 @@ class BaseParser(argparse.ArgumentParser):
     data_dir: Create a flag for specifying the input data directory.
     model_dir: Create a flag for specifying the model file directory.
     train_epochs: Create a flag to specify the number of training epochs.
-    epochs_per_eval: Create a flag to specify the frequency of testing.
+    epochs_between_evals: Create a flag to specify the frequency of testing.
    batch_size: Create a flag to specify the batch size.
     multi_gpu: Create a flag to allow the use of all available GPUs.
     hooks: Create a flag to specify hooks for logging.
   """

   def __init__(self, add_help=False, data_dir=True, model_dir=True,
-               train_epochs=True, epochs_per_eval=True, batch_size=True,
+               train_epochs=True, epochs_between_evals=True, batch_size=True,
                multi_gpu=True, hooks=True):
     super(BaseParser, self).__init__(add_help=add_help)
@@ -91,7 +91,8 @@ class BaseParser(argparse.ArgumentParser):
     if model_dir:
       self.add_argument(
           "--model_dir", "-md", default="/tmp",
-          help="[default: %(default)s] The location of the model files.",
+          help="[default: %(default)s] The location of the model checkpoint "
+               "files.",
           metavar="<MD>",
       )
@@ -102,12 +103,12 @@ class BaseParser(argparse.ArgumentParser):
           metavar="<TE>"
       )

-    if epochs_per_eval:
+    if epochs_between_evals:
       self.add_argument(
-          "--epochs_per_eval", "-epe", type=int, default=1,
+          "--epochs_between_evals", "-ebe", type=int, default=1,
           help="[default: %(default)s] The number of training epochs to run "
               "between evaluations.",
-          metavar="<EPE>"
+          metavar="<EBE>"
       )

     if batch_size:
@@ -214,6 +215,8 @@ class ImageModelParser(argparse.ArgumentParser):
     if data_format:
       self.add_argument(
           "--data_format", "-df",
+          default=None,
+          choices=['channels_first', 'channels_last'],
           help="A flag to override the data format used in the model. "
                "channels_first provides a performance boost on GPU but is not "
                "always compatible with CPU. If left unspecified, the data "
......
@@ -42,7 +42,7 @@ class BaseTester(unittest.TestCase):
         data_dir="dfgasf",
         model_dir="dfsdkjgbs",
         train_epochs=534,
-        epochs_per_eval=15,
+        epochs_between_evals=15,
         batch_size=256,
         hooks=["LoggingTensorHook"],
         num_parallel_calls=18,
......
@@ -63,20 +63,25 @@ def get_train_hooks(name_list, **kwargs):
   return train_hooks


-def get_logging_tensor_hook(every_n_iter=100, **kwargs):  # pylint: disable=unused-argument
+def get_logging_tensor_hook(every_n_iter=100, tensors_to_log=None, **kwargs):  # pylint: disable=unused-argument
   """Function to get LoggingTensorHook.

   Args:
     every_n_iter: `int`, print the values of `tensors` once every N local
       steps taken on the current worker.
+    tensors_to_log: List of tensor names or dictionary mapping labels to tensor
+      names. If not set, log _TENSORS_TO_LOG by default.
     kwargs: a dictionary of arguments to LoggingTensorHook.

   Returns:
     Returns a LoggingTensorHook with a standard set of tensors that will be
     printed to stdout.
   """
+  if tensors_to_log is None:
+    tensors_to_log = _TENSORS_TO_LOG
+
   return tf.train.LoggingTensorHook(
-      tensors=_TENSORS_TO_LOG,
+      tensors=tensors_to_log,
       every_n_iter=every_n_iter)
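With `tensors_to_log` threaded through `get_train_hooks`'s `**kwargs`, callers can now override the default `_TENSORS_TO_LOG` set, as the wide_deep change below does. A sketch of such a call (the flag values are illustrative; the tensor names are the `tf.identity` names introduced by the mnist change in this same commit, and must exist in the caller's graph):

```python
from official.utils.logging import hooks_helper

# Build a LoggingTensorHook that reports two named tensors every 100 steps.
# Unrelated kwargs (batch_size here) are absorbed by the hook factories
# that do not need them.
train_hooks = hooks_helper.get_train_hooks(
    ['LoggingTensorHook'],
    batch_size=128,
    tensors_to_log={'loss': 'cross_entropy',
                    'accuracy': 'train_accuracy'})
```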
......
@@ -44,7 +44,7 @@ def run_synthetic(main, tmp_root, extra_flags=None):
   model_dir = tempfile.mkdtemp(dir=tmp_root)
   args = [sys.argv[0], "--model_dir", model_dir, "--train_epochs", "1",
-          "--epochs_per_eval", "1", "--use_synthetic_data",
+          "--epochs_between_evals", "1", "--use_synthetic_data",
           "--max_train_steps", "1"] + extra_flags
   try:
......
@@ -15,6 +15,8 @@ The input function for the `Estimator` uses `tf.contrib.data.TextLineDataset`, w
 The `Estimator` and `Dataset` APIs are both highly encouraged for fast development and efficient training.

 ## Running the code
+Make sure to run the command to export the `/models` folder to the python path: https://github.com/tensorflow/models/tree/master/official#running-the-models

 ### Setup
 The [Census Income Data Set](https://archive.ics.uci.edu/ml/datasets/Census+Income) that this sample uses for training is hosted by the [UC Irvine Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/). We have provided a script that downloads and cleans the necessary files.
......
@@ -18,11 +18,15 @@ from __future__ import division
 from __future__ import print_function

 import argparse
+import os
 import shutil
 import sys

 import tensorflow as tf

+from official.utils.arg_parsers import parsers  # pylint: disable=g-bad-import-order
+from official.utils.logging import hooks_helper
+
 _CSV_COLUMNS = [
     'age', 'workclass', 'fnlwgt', 'education', 'education_num',
     'marital_status', 'occupation', 'relationship', 'race', 'gender',
@@ -33,34 +37,6 @@ _CSV_COLUMNS = [
 _CSV_COLUMN_DEFAULTS = [[0], [''], [0], [''], [0], [''], [''], [''], [''], [''],
                         [0], [0], [0], [''], ['']]

-parser = argparse.ArgumentParser()
-
-parser.add_argument(
-    '--model_dir', type=str, default='/tmp/census_model',
-    help='Base directory for the model.')
-
-parser.add_argument(
-    '--model_type', type=str, default='wide_deep',
-    help="Valid model types: {'wide', 'deep', 'wide_deep'}.")
-
-parser.add_argument(
-    '--train_epochs', type=int, default=40, help='Number of training epochs.')
-
-parser.add_argument(
-    '--epochs_per_eval', type=int, default=2,
-    help='The number of training epochs to run between evaluations.')
-
-parser.add_argument(
-    '--batch_size', type=int, default=40, help='Number of examples per batch.')
-
-parser.add_argument(
-    '--train_data', type=str, default='/tmp/census_data/adult.data',
-    help='Path to the training data.')
-
-parser.add_argument(
-    '--test_data', type=str, default='/tmp/census_data/adult.test',
-    help='Path to the test data.')
-
 _NUM_EXAMPLES = {
     'train': 32561,
     'validation': 16281,
@@ -170,8 +146,8 @@ def build_estimator(model_dir, model_type):
 def input_fn(data_file, num_epochs, shuffle, batch_size):
   """Generate an input function for the Estimator."""
   assert tf.gfile.Exists(data_file), (
-      '%s not found. Please make sure you have either run data_download.py or '
-      'set both arguments --train_data and --test_data.' % data_file)
+      '%s not found. Please make sure you have run data_download.py and '
+      'set the --data_dir argument to the correct path.' % data_file)

   def parse_csv(value):
     print('Parsing', data_file)
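For context, the `input_fn` whose error message changes here is the usual `TextLineDataset` plus `tf.decode_csv` recipe. A toy version of that recipe, under an assumed two-column schema rather than the real census columns:

```python
import tensorflow as tf

_TOY_COLUMNS = ['age', 'income']  # illustrative, not the census schema
_TOY_DEFAULTS = [[0], ['']]       # column types are inferred from defaults

def toy_input_fn(data_file, num_epochs, shuffle, batch_size):
  def parse_csv(line):
    # decode_csv splits one text line into typed column tensors.
    fields = tf.decode_csv(line, record_defaults=_TOY_DEFAULTS)
    features = dict(zip(_TOY_COLUMNS, fields))
    label = features.pop('income')
    return features, tf.equal(label, '>50K')  # binary label

  ds = tf.data.TextLineDataset(data_file)
  if shuffle:
    ds = ds.shuffle(buffer_size=1000)
  return ds.map(parse_csv).repeat(num_epochs).batch(batch_size)
```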
@@ -200,23 +176,51 @@ def main(unused_argv):
   shutil.rmtree(FLAGS.model_dir, ignore_errors=True)
   model = build_estimator(FLAGS.model_dir, FLAGS.model_type)

-  # Train and evaluate the model every `FLAGS.epochs_per_eval` epochs.
-  for n in range(FLAGS.train_epochs // FLAGS.epochs_per_eval):
-    model.train(input_fn=lambda: input_fn(
-        FLAGS.train_data, FLAGS.epochs_per_eval, True, FLAGS.batch_size))
+  train_file = os.path.join(FLAGS.data_dir, 'adult.data')
+  test_file = os.path.join(FLAGS.data_dir, 'adult.test')
+
+  train_hooks = hooks_helper.get_train_hooks(
+      FLAGS.hooks, batch_size=FLAGS.batch_size,
+      tensors_to_log={'average_loss': 'head/truediv',
+                      'loss': 'head/weighted_loss/Sum'})
+
+  # Train and evaluate the model every `FLAGS.epochs_between_evals` epochs.
+  for n in range(FLAGS.train_epochs // FLAGS.epochs_between_evals):
+    model.train(
+        input_fn=lambda: input_fn(train_file, FLAGS.epochs_between_evals, True,
+                                  FLAGS.batch_size),
+        hooks=train_hooks)

     results = model.evaluate(input_fn=lambda: input_fn(
-        FLAGS.test_data, 1, False, FLAGS.batch_size))
+        test_file, 1, False, FLAGS.batch_size))

     # Display evaluation metrics
-    print('Results at epoch', (n + 1) * FLAGS.epochs_per_eval)
+    print('Results at epoch', (n + 1) * FLAGS.epochs_between_evals)
     print('-' * 60)

     for key in sorted(results):
       print('%s: %s' % (key, results[key]))


+class WideDeepArgParser(argparse.ArgumentParser):
+  """Argument parser for running the wide deep model."""
+
+  def __init__(self):
+    super(WideDeepArgParser, self).__init__(parents=[parsers.BaseParser()])
+    self.add_argument(
+        '--model_type', '-mt', type=str, default='wide_deep',
+        choices=['wide', 'deep', 'wide_deep'],
+        help='[default %(default)s] Valid model types: wide, deep, wide_deep.',
+        metavar='<MT>')
+    self.set_defaults(
+        data_dir='/tmp/census_data',
+        model_dir='/tmp/census_model',
+        train_epochs=40,
+        epochs_between_evals=2,
+        batch_size=40)
+
+
 if __name__ == '__main__':
   tf.logging.set_verbosity(tf.logging.INFO)
+  parser = WideDeepArgParser()
   FLAGS, unparsed = parser.parse_known_args()
   tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)