Commit 065180e6 authored by Billy Lamberta's avatar Billy Lamberta
Browse files

Merge branch 'master' of github.com:tensorflow/models into getting-started

parents 3381b321 16af6582
...@@ -26,7 +26,7 @@ import sys ...@@ -26,7 +26,7 @@ import sys
import tempfile import tempfile
def run_synthetic(main, tmp_root, extra_flags=None): def run_synthetic(main, tmp_root, extra_flags=None, synth=True, max_train=1):
"""Performs a minimal run of a model. """Performs a minimal run of a model.
This function is intended to test for syntax errors throughout a model. A This function is intended to test for syntax errors throughout a model. A
...@@ -37,6 +37,8 @@ def run_synthetic(main, tmp_root, extra_flags=None): ...@@ -37,6 +37,8 @@ def run_synthetic(main, tmp_root, extra_flags=None):
function is "<MODULE>.main(argv)". function is "<MODULE>.main(argv)".
tmp_root: Root path for the temp directory created by the test class. tmp_root: Root path for the temp directory created by the test class.
extra_flags: Additional flags passed by the caller of this function. extra_flags: Additional flags passed by the caller of this function.
synth: Use synthetic data.
max_train: Maximum number of allowed training steps.
""" """
extra_flags = [] if extra_flags is None else extra_flags extra_flags = [] if extra_flags is None else extra_flags
...@@ -44,8 +46,13 @@ def run_synthetic(main, tmp_root, extra_flags=None): ...@@ -44,8 +46,13 @@ def run_synthetic(main, tmp_root, extra_flags=None):
model_dir = tempfile.mkdtemp(dir=tmp_root) model_dir = tempfile.mkdtemp(dir=tmp_root)
args = [sys.argv[0], "--model_dir", model_dir, "--train_epochs", "1", args = [sys.argv[0], "--model_dir", model_dir, "--train_epochs", "1",
"--epochs_between_evals", "1", "--use_synthetic_data", "--epochs_between_evals", "1"] + extra_flags
"--max_train_steps", "1"] + extra_flags
if synth:
args.append("--use_synthetic_data")
if max_train is not None:
args.extend(["--max_train_steps", str(max_train)])
try: try:
main(args) main(args)
......
...@@ -43,6 +43,9 @@ _NUM_EXAMPLES = { ...@@ -43,6 +43,9 @@ _NUM_EXAMPLES = {
} }
LOSS_PREFIX = {'wide': 'linear/', 'deep': 'dnn/'}
def build_model_columns(): def build_model_columns():
"""Builds a set of wide and deep feature columns.""" """Builds a set of wide and deep feature columns."""
# Continuous columns # Continuous columns
...@@ -190,10 +193,11 @@ def main(argv): ...@@ -190,10 +193,11 @@ def main(argv):
def eval_input_fn(): def eval_input_fn():
return input_fn(test_file, 1, False, flags.batch_size) return input_fn(test_file, 1, False, flags.batch_size)
loss_prefix = LOSS_PREFIX.get(flags.model_type, '')
train_hooks = hooks_helper.get_train_hooks( train_hooks = hooks_helper.get_train_hooks(
flags.hooks, batch_size=flags.batch_size, flags.hooks, batch_size=flags.batch_size,
tensors_to_log={'average_loss': 'head/truediv', tensors_to_log={'average_loss': loss_prefix + 'head/truediv',
'loss': 'head/weighted_loss/Sum'}) 'loss': loss_prefix + 'head/weighted_loss/Sum'})
# Train and evaluate the model every `flags.epochs_between_evals` epochs. # Train and evaluate the model every `flags.epochs_between_evals` epochs.
for n in range(flags.train_epochs // flags.epochs_between_evals): for n in range(flags.train_epochs // flags.epochs_between_evals):
......
...@@ -21,6 +21,7 @@ import os ...@@ -21,6 +21,7 @@ import os
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.testing import integration
from official.wide_deep import wide_deep from official.wide_deep import wide_deep
tf.logging.set_verbosity(tf.logging.ERROR) tf.logging.set_verbosity(tf.logging.ERROR)
...@@ -54,6 +55,14 @@ class BaseTest(tf.test.TestCase): ...@@ -54,6 +55,14 @@ class BaseTest(tf.test.TestCase):
with tf.gfile.Open(self.input_csv, 'w') as temp_csv: with tf.gfile.Open(self.input_csv, 'w') as temp_csv:
temp_csv.write(TEST_INPUT) temp_csv.write(TEST_INPUT)
with tf.gfile.Open(TEST_CSV, "r") as temp_csv:
test_csv_contents = temp_csv.read()
# Used for end-to-end tests.
for fname in ['adult.data', 'adult.test']:
with tf.gfile.Open(os.path.join(self.temp_dir, fname), 'w') as test_csv:
test_csv.write(test_csv_contents)
def test_input_fn(self): def test_input_fn(self):
dataset = wide_deep.input_fn(self.input_csv, 1, False, 1) dataset = wide_deep.input_fn(self.input_csv, 1, False, 1)
features, labels = dataset.make_one_shot_iterator().get_next() features, labels = dataset.make_one_shot_iterator().get_next()
...@@ -107,6 +116,30 @@ class BaseTest(tf.test.TestCase): ...@@ -107,6 +116,30 @@ class BaseTest(tf.test.TestCase):
def test_wide_deep_estimator_training(self): def test_wide_deep_estimator_training(self):
self.build_and_test_estimator('wide_deep') self.build_and_test_estimator('wide_deep')
def test_end_to_end_wide(self):
integration.run_synthetic(
main=wide_deep.main, tmp_root=self.get_temp_dir(), extra_flags=[
'--data_dir', self.get_temp_dir(),
'--model_type', 'wide',
],
synth=False, max_train=None)
def test_end_to_end_deep(self):
integration.run_synthetic(
main=wide_deep.main, tmp_root=self.get_temp_dir(), extra_flags=[
'--data_dir', self.get_temp_dir(),
'--model_type', 'deep',
],
synth=False, max_train=None)
def test_end_to_end_wide_deep(self):
integration.run_synthetic(
main=wide_deep.main, tmp_root=self.get_temp_dir(), extra_flags=[
'--data_dir', self.get_temp_dir(),
'--model_type', 'wide_deep',
],
synth=False, max_train=None)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment