Commit 4cfb259f authored by XinyueZ's avatar XinyueZ
Browse files

Removed from_dataset(), let make_dataset() return input_fn as provider of dataset

parent 7a34628e
...@@ -110,11 +110,10 @@ def load_data(y_name="price", train_fraction=0.7, seed=None): ...@@ -110,11 +110,10 @@ def load_data(y_name="price", train_fraction=0.7, seed=None):
return (x_train, y_train), (x_test, y_test) return (x_train, y_train), (x_test, y_test)
def from_dataset(dataset): return lambda: dataset.make_one_shot_iterator().get_next()
def make_dataset(batch_sz, x, y=None, shuffle=False, shuffle_buffer_size=1000): def make_dataset(batch_sz, x, y=None, shuffle=False, shuffle_buffer_size=1000):
"""Create a slice Dataset from a pandas DataFrame and labels""" """Create a slice Dataset from a pandas DataFrame and labels"""
def input_fn():
if y is not None: if y is not None:
dataset = tf.data.Dataset.from_tensor_slices((dict(x), y)) dataset = tf.data.Dataset.from_tensor_slices((dict(x), y))
else: else:
...@@ -123,4 +122,6 @@ def make_dataset(batch_sz, x, y=None, shuffle=False, shuffle_buffer_size=1000): ...@@ -123,4 +122,6 @@ def make_dataset(batch_sz, x, y=None, shuffle=False, shuffle_buffer_size=1000):
dataset = dataset.shuffle(shuffle_buffer_size).batch(batch_sz).repeat() dataset = dataset.shuffle(shuffle_buffer_size).batch(batch_sz).repeat()
else: else:
dataset = dataset.batch(batch_sz) dataset = dataset.batch(batch_sz)
return dataset return dataset.make_one_shot_iterator().get_next()
return input_fn
...@@ -101,11 +101,11 @@ def main(argv): ...@@ -101,11 +101,11 @@ def main(argv):
train_y /= args.price_norm_factor train_y /= args.price_norm_factor
test_y /= args.price_norm_factor test_y /= args.price_norm_factor
# Build the training dataset. # Provide the training input dataset.
train = automobile_data.make_dataset(args.batch_size, train_x, train_y, True, 1000) train_input_fn = automobile_data.make_dataset(args.batch_size, train_x, train_y, True, 1000)
# Build the validation dataset. # Build the validation dataset.
test = automobile_data.make_dataset(args.batch_size, test_x, test_y) test_input_fn = automobile_data.make_dataset(args.batch_size, test_x, test_y)
# The first way assigns a unique weight to each category. To do this you must # The first way assigns a unique weight to each category. To do this you must
# specify the category's vocabulary (values outside this specification will # specify the category's vocabulary (values outside this specification will
...@@ -144,10 +144,10 @@ def main(argv): ...@@ -144,10 +144,10 @@ def main(argv):
}) })
# Train the model. # Train the model.
model.train(input_fn=automobile_data.from_dataset(train), steps=args.train_steps) model.train(input_fn=train_input_fn, steps=args.train_steps)
# Evaluate how the model performs on data it has not yet seen. # Evaluate how the model performs on data it has not yet seen.
eval_result = model.evaluate(input_fn=automobile_data.from_dataset(test)) eval_result = model.evaluate(input_fn=test_input_fn)
# Print the Root Mean Square Error (RMSE). # Print the Root Mean Square Error (RMSE).
print("\n" + 80 * "*") print("\n" + 80 * "*")
......
...@@ -41,11 +41,11 @@ def main(argv): ...@@ -41,11 +41,11 @@ def main(argv):
train_y /= args.price_norm_factor train_y /= args.price_norm_factor
test_y /= args.price_norm_factor test_y /= args.price_norm_factor
# Build the training dataset. # Provide the training input dataset.
train = automobile_data.make_dataset(args.batch_size, train_x, train_y, True, 1000) train_input_fn = automobile_data.make_dataset(args.batch_size, train_x, train_y, True, 1000)
# Build the validation dataset. # Provide the validation input dataset.
test = automobile_data.make_dataset(args.batch_size, test_x, test_y) test_input_fn = automobile_data.make_dataset(args.batch_size, test_x, test_y)
# Use the same categorical columns as in `linear_regression_categorical` # Use the same categorical columns as in `linear_regression_categorical`
body_style_vocab = ["hardtop", "wagon", "sedan", "hatchback", "convertible"] body_style_vocab = ["hardtop", "wagon", "sedan", "hatchback", "convertible"]
...@@ -74,10 +74,10 @@ def main(argv): ...@@ -74,10 +74,10 @@ def main(argv):
# Train the model. # Train the model.
# By default, the Estimators log output every 100 steps. # By default, the Estimators log output every 100 steps.
model.train(input_fn=automobile_data.from_dataset(train), steps=args.train_steps) model.train(input_fn=train_input_fn, steps=args.train_steps)
# Evaluate how the model performs on data it has not yet seen. # Evaluate how the model performs on data it has not yet seen.
eval_result = model.evaluate(input_fn=automobile_data.from_dataset(test)) eval_result = model.evaluate(input_fn=test_input_fn)
# The evaluation returns a Python dictionary. The "average_loss" key holds the # The evaluation returns a Python dictionary. The "average_loss" key holds the
# Mean Squared Error (MSE). # Mean Squared Error (MSE).
......
...@@ -42,11 +42,11 @@ def main(argv): ...@@ -42,11 +42,11 @@ def main(argv):
train_y /= args.price_norm_factor train_y /= args.price_norm_factor
test_y /= args.price_norm_factor test_y /= args.price_norm_factor
# Build the training dataset. # Provide the training input dataset.
train = automobile_data.make_dataset(args.batch_size, train_x, train_y, True, 1000) train_input_fn = automobile_data.make_dataset(args.batch_size, train_x, train_y, True, 1000)
# Build the validation dataset. # Provide the validation input dataset.
test = automobile_data.make_dataset(args.batch_size, test_x, test_y) test_input_fn = automobile_data.make_dataset(args.batch_size, test_x, test_y)
feature_columns = [ feature_columns = [
# "curb-weight" and "highway-mpg" are numeric columns. # "curb-weight" and "highway-mpg" are numeric columns.
...@@ -59,10 +59,10 @@ def main(argv): ...@@ -59,10 +59,10 @@ def main(argv):
# Train the model. # Train the model.
# By default, the Estimators log output every 100 steps. # By default, the Estimators log output every 100 steps.
model.train(input_fn=automobile_data.from_dataset(train), steps=args.train_steps) model.train(input_fn=train_input_fn, steps=args.train_steps)
# Evaluate how the model performs on data it has not yet seen. # Evaluate how the model performs on data it has not yet seen.
eval_result = model.evaluate(input_fn=automobile_data.from_dataset(test)) eval_result = model.evaluate(input_fn=test_input_fn)
# The evaluation returns a Python dictionary. The "average_loss" key holds the # The evaluation returns a Python dictionary. The "average_loss" key holds the
# Mean Squared Error (MSE). # Mean Squared Error (MSE).
...@@ -79,8 +79,9 @@ def main(argv): ...@@ -79,8 +79,9 @@ def main(argv):
"highway-mpg": np.array([30, 40]) "highway-mpg": np.array([30, 40])
} }
predict = automobile_data.make_dataset(1, input_dict) # Provide the predict input dataset.
predict_results = model.predict(input_fn=automobile_data.from_dataset(predict)) predict_input_fn = automobile_data.make_dataset(1, input_dict)
predict_results = model.predict(input_fn=predict_input_fn)
# Print the prediction results. # Print the prediction results.
print("\nPrediction results:") print("\nPrediction results:")
......
...@@ -41,11 +41,11 @@ def main(argv): ...@@ -41,11 +41,11 @@ def main(argv):
train_y /= args.price_norm_factor train_y /= args.price_norm_factor
test_y /= args.price_norm_factor test_y /= args.price_norm_factor
# Build the training dataset. # Provide the training input dataset.
train = automobile_data.make_dataset(args.batch_size, train_x, train_y, True, 1000) train_input_fn = automobile_data.make_dataset(args.batch_size, train_x, train_y, True, 1000)
# Build the validation dataset. # Provide the validation input dataset.
test = automobile_data.make_dataset(args.batch_size, test_x, test_y) test_input_fn = automobile_data.make_dataset(args.batch_size, test_x, test_y)
# The following code demonstrates two of the ways that `feature_columns` can # The following code demonstrates two of the ways that `feature_columns` can
# be used to build a model with categorical inputs. # be used to build a model with categorical inputs.
...@@ -83,10 +83,10 @@ def main(argv): ...@@ -83,10 +83,10 @@ def main(argv):
# Train the model. # Train the model.
# By default, the Estimators log output every 100 steps. # By default, the Estimators log output every 100 steps.
model.train(input_fn=automobile_data.from_dataset(train), steps=args.train_steps) model.train(input_fn=train_input_fn, steps=args.train_steps)
# Evaluate how the model performs on data it has not yet seen. # Evaluate how the model performs on data it has not yet seen.
eval_result = model.evaluate(input_fn=automobile_data.from_dataset(test)) eval_result = model.evaluate(input_fn=test_input_fn)
# The evaluation returns a Python dictionary. The "average_loss" key holds the # The evaluation returns a Python dictionary. The "average_loss" key holds the
# Mean Squared Error (MSE). # Mean Squared Error (MSE).
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment