Unverified Commit 5f9f6b84 authored by Taylor Robie's avatar Taylor Robie Committed by GitHub
Browse files

Move argparsing from builtin argparse to absl (#4099)

* squash of modular absl usage commits

* delint

* address PR comments

* change hooks to comma separated list, as absl behavior for space separated lists is not as expected
parent 6ec3452c
......@@ -22,12 +22,15 @@ import os
import shutil
import sys
from absl import app as absl_app
from absl import flags
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.arg_parsers import parsers
from official.utils.flags import core as flags_core
from official.utils.logs import hooks_helper
from official.utils.misc import model_helpers
_CSV_COLUMNS = [
'age', 'workclass', 'fnlwgt', 'education', 'education_num',
'marital_status', 'occupation', 'relationship', 'race', 'gender',
......@@ -47,6 +50,24 @@ _NUM_EXAMPLES = {
LOSS_PREFIX = {'wide': 'linear/', 'deep': 'dnn/'}
def define_wide_deep_flags():
"""Add supervised learning flags, as well as wide-deep model type."""
flags_core.define_base()
flags.adopt_module_key_flags(flags_core)
flags.DEFINE_enum(
name="model_type", short_name="mt", default="wide_deep",
enum_values=['wide', 'deep', 'wide_deep'],
help="Select model topology.")
flags_core.set_defaults(data_dir='/tmp/census_data',
model_dir='/tmp/census_model',
train_epochs=40,
epochs_between_evals=2,
batch_size=40)
def build_model_columns():
"""Builds a set of wide and deep feature columns."""
# Continuous columns
......@@ -196,70 +217,50 @@ def export_model(model, model_type, export_dir):
model.export_savedmodel(export_dir, example_input_fn)
def main(argv):
parser = WideDeepArgParser()
flags = parser.parse_args(args=argv[1:])
def main(flags_obj):
# Clean up the model directory if present
shutil.rmtree(flags.model_dir, ignore_errors=True)
model = build_estimator(flags.model_dir, flags.model_type)
shutil.rmtree(flags_obj.model_dir, ignore_errors=True)
model = build_estimator(flags_obj.model_dir, flags_obj.model_type)
train_file = os.path.join(flags.data_dir, 'adult.data')
test_file = os.path.join(flags.data_dir, 'adult.test')
train_file = os.path.join(flags_obj.data_dir, 'adult.data')
test_file = os.path.join(flags_obj.data_dir, 'adult.test')
# Train and evaluate the model every `flags.epochs_between_evals` epochs.
def train_input_fn():
return input_fn(
train_file, flags.epochs_between_evals, True, flags.batch_size)
train_file, flags_obj.epochs_between_evals, True, flags_obj.batch_size)
def eval_input_fn():
return input_fn(test_file, 1, False, flags.batch_size)
return input_fn(test_file, 1, False, flags_obj.batch_size)
loss_prefix = LOSS_PREFIX.get(flags.model_type, '')
loss_prefix = LOSS_PREFIX.get(flags_obj.model_type, '')
train_hooks = hooks_helper.get_train_hooks(
flags.hooks, batch_size=flags.batch_size,
flags_obj.hooks, batch_size=flags_obj.batch_size,
tensors_to_log={'average_loss': loss_prefix + 'head/truediv',
'loss': loss_prefix + 'head/weighted_loss/Sum'})
# Train and evaluate the model every `flags.epochs_between_evals` epochs.
for n in range(flags.train_epochs // flags.epochs_between_evals):
for n in range(flags_obj.train_epochs // flags_obj.epochs_between_evals):
model.train(input_fn=train_input_fn, hooks=train_hooks)
results = model.evaluate(input_fn=eval_input_fn)
# Display evaluation metrics
print('Results at epoch', (n + 1) * flags.epochs_between_evals)
print('Results at epoch', (n + 1) * flags_obj.epochs_between_evals)
print('-' * 60)
for key in sorted(results):
print('%s: %s' % (key, results[key]))
if model_helpers.past_stop_threshold(
flags.stop_threshold, results['accuracy']):
flags_obj.stop_threshold, results['accuracy']):
break
# Export the model
if flags.export_dir is not None:
export_model(model, flags.model_type, flags.export_dir)
class WideDeepArgParser(argparse.ArgumentParser):
"""Argument parser for running the wide deep model."""
def __init__(self):
super(WideDeepArgParser, self).__init__(parents=[parsers.BaseParser()])
self.add_argument(
'--model_type', '-mt', type=str, default='wide_deep',
choices=['wide', 'deep', 'wide_deep'],
help='[default %(default)s] Valid model types: wide, deep, wide_deep.',
metavar='<MT>')
self.set_defaults(
data_dir='/tmp/census_data',
model_dir='/tmp/census_model',
train_epochs=40,
epochs_between_evals=2,
batch_size=40)
if flags_obj.export_dir is not None:
export_model(model, flags_obj.model_type, flags_obj.export_dir)
if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO)
main(argv=sys.argv)
define_wide_deep_flags()
absl_app.run(main)
......@@ -48,6 +48,11 @@ TEST_CSV = os.path.join(os.path.dirname(__file__), 'wide_deep_test.csv')
class BaseTest(tf.test.TestCase):
"""Tests for Wide Deep model."""
@classmethod
def setUpClass(cls): # pylint: disable=invalid-name
super(BaseTest, cls).setUpClass()
wide_deep.define_wide_deep_flags()
def setUp(self):
# Create temporary CSV file
self.temp_dir = self.get_temp_dir()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment