Commit 538f89c4 authored by Neal Wu's avatar Neal Wu Committed by GitHub
Browse files

Fix wide_deep_test's TEST_CSV path and only test the wide_deep model for the sake of time (#2566)

parent 9683ee99
......@@ -41,8 +41,7 @@ TEST_INPUT_VALUES = {
'occupation': 'abc',
}
TEST_TRAINING_CSV = os.path.join(os.path.dirname(os.path.realpath(__file__)),
'wide_deep_test.csv')
TEST_CSV = os.path.join(os.path.dirname(__file__), 'wide_deep_test.csv')
class BaseTest(tf.test.TestCase):
......@@ -80,31 +79,25 @@ class BaseTest(tf.test.TestCase):
# Train for 1 step to initialize model and evaluate initial loss
model.train(
input_fn=wide_deep.input_fn(
TEST_TRAINING_CSV, num_epochs=1, shuffle=True, batch_size=1),
TEST_CSV, num_epochs=1, shuffle=True, batch_size=1),
steps=1)
initial_results = model.evaluate(
input_fn=wide_deep.input_fn(
TEST_TRAINING_CSV, num_epochs=1, shuffle=False, batch_size=1))
TEST_CSV, num_epochs=1, shuffle=False, batch_size=1))
# Train for 40 steps at batch size 2 and evaluate final loss
model.train(
input_fn=wide_deep.input_fn(
TEST_TRAINING_CSV, num_epochs=None, shuffle=True, batch_size=2),
TEST_CSV, num_epochs=None, shuffle=True, batch_size=2),
steps=40)
final_results = model.evaluate(
input_fn=wide_deep.input_fn(
TEST_TRAINING_CSV, num_epochs=1, shuffle=False, batch_size=1))
TEST_CSV, num_epochs=1, shuffle=False, batch_size=1))
print('%s initial results:' % model_type, initial_results)
print('%s final results:' % model_type, final_results)
self.assertLess(final_results['loss'], initial_results['loss'])
def test_deep_estimator_training(self):
self.build_and_test_estimator('deep')
def test_wide_estimator_training(self):
self.build_and_test_estimator('wide')
def test_wide_deep_estimator_training(self):
self.build_and_test_estimator('wide_deep')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment