"vscode:/vscode.git/clone" did not exist on "4ad0d4d6d3bd9b3507644d872fb27113fd3e487b"
Commit 065180e6 authored by Billy Lamberta's avatar Billy Lamberta
Browse files

Merge branch 'master' of github.com:tensorflow/models into getting-started

parents 3381b321 16af6582
......@@ -26,7 +26,7 @@ import sys
import tempfile
def run_synthetic(main, tmp_root, extra_flags=None):
def run_synthetic(main, tmp_root, extra_flags=None, synth=True, max_train=1):
"""Performs a minimal run of a model.
This function is intended to test for syntax errors throughout a model. A
......@@ -37,6 +37,8 @@ def run_synthetic(main, tmp_root, extra_flags=None):
function is "<MODULE>.main(argv)".
tmp_root: Root path for the temp directory created by the test class.
extra_flags: Additional flags passed by the caller of this function.
synth: Use synthetic data.
max_train: Maximum number of allowed training steps.
"""
extra_flags = [] if extra_flags is None else extra_flags
......@@ -44,8 +46,13 @@ def run_synthetic(main, tmp_root, extra_flags=None):
model_dir = tempfile.mkdtemp(dir=tmp_root)
args = [sys.argv[0], "--model_dir", model_dir, "--train_epochs", "1",
"--epochs_between_evals", "1", "--use_synthetic_data",
"--max_train_steps", "1"] + extra_flags
"--epochs_between_evals", "1"] + extra_flags
if synth:
args.append("--use_synthetic_data")
if max_train is not None:
args.extend(["--max_train_steps", str(max_train)])
try:
main(args)
......
......@@ -43,6 +43,9 @@ _NUM_EXAMPLES = {
}
# Name-scope prefix prepended to the loss tensor names logged by the training
# hooks, keyed by the value of flags.model_type (see the get_train_hooks call
# in main). An unknown model type falls back to '' via .get().
LOSS_PREFIX = {'wide': 'linear/', 'deep': 'dnn/'}
def build_model_columns():
"""Builds a set of wide and deep feature columns."""
# Continuous columns
......@@ -190,10 +193,11 @@ def main(argv):
def eval_input_fn():
  # Evaluation pipeline over the test split; positional args presumably map to
  # (data_file, num_epochs=1, shuffle=False, batch_size) — confirm against the
  # input_fn signature.
  return input_fn(test_file, 1, False, flags.batch_size)
loss_prefix = LOSS_PREFIX.get(flags.model_type, '')
train_hooks = hooks_helper.get_train_hooks(
flags.hooks, batch_size=flags.batch_size,
tensors_to_log={'average_loss': 'head/truediv',
'loss': 'head/weighted_loss/Sum'})
tensors_to_log={'average_loss': loss_prefix + 'head/truediv',
'loss': loss_prefix + 'head/weighted_loss/Sum'})
# Train and evaluate the model every `flags.epochs_between_evals` epochs.
for n in range(flags.train_epochs // flags.epochs_between_evals):
......
......@@ -21,6 +21,7 @@ import os
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.testing import integration
from official.wide_deep import wide_deep
# Suppress INFO/WARNING log spam from TensorFlow so test output stays readable.
tf.logging.set_verbosity(tf.logging.ERROR)
......@@ -54,6 +55,14 @@ class BaseTest(tf.test.TestCase):
with tf.gfile.Open(self.input_csv, 'w') as temp_csv:
temp_csv.write(TEST_INPUT)
with tf.gfile.Open(TEST_CSV, "r") as temp_csv:
test_csv_contents = temp_csv.read()
# Used for end-to-end tests.
for fname in ['adult.data', 'adult.test']:
with tf.gfile.Open(os.path.join(self.temp_dir, fname), 'w') as test_csv:
test_csv.write(test_csv_contents)
def test_input_fn(self):
dataset = wide_deep.input_fn(self.input_csv, 1, False, 1)
features, labels = dataset.make_one_shot_iterator().get_next()
......@@ -107,6 +116,30 @@ class BaseTest(tf.test.TestCase):
def test_wide_deep_estimator_training(self):
self.build_and_test_estimator('wide_deep')
def _test_end_to_end(self, model_type):
  """Runs wide_deep.main end to end for the given model type.

  Drives the real (non-synthetic) input pipeline against the fake census CSV
  files written in setUp, with no cap on training steps, so the full
  train/eval path is exercised for one run.

  Args:
    model_type: Model selector forwarded via the --model_type flag; one of
      'wide', 'deep', or 'wide_deep'.
  """
  integration.run_synthetic(
      main=wide_deep.main, tmp_root=self.get_temp_dir(), extra_flags=[
          '--data_dir', self.get_temp_dir(),
          '--model_type', model_type,
      ],
      synth=False, max_train=None)

def test_end_to_end_wide(self):
  self._test_end_to_end('wide')

def test_end_to_end_deep(self):
  self._test_end_to_end('deep')

def test_end_to_end_wide_deep(self):
  self._test_end_to_end('wide_deep')
# Standard TensorFlow test entry point: discovers and runs the TestCase above.
if __name__ == '__main__':
  tf.test.main()
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment