Unverified Commit 0f5803bd authored by Mark Daoust's avatar Mark Daoust Committed by GitHub
Browse files

Merge pull request #3028 from tensorflow/mhyttsten-patch-2

Update blog_custom_estimators.py
parents d19587e4 688143ed
...@@ -14,8 +14,7 @@ ...@@ -14,8 +14,7 @@
# ============================================================================== # ==============================================================================
# This is the complete code for the following blogpost: # This is the complete code for the following blogpost:
# https://developers.googleblog.com/2017/09/introducing-tensorflow-datasets.html # https://developers.googleblog.com/2017/12/creating-custom-estimators-in-tensorflow.html
# (https://goo.gl/Ujm2Ep)
import tensorflow as tf import tensorflow as tf
import os import os
...@@ -116,7 +115,7 @@ def my_model_fn( ...@@ -116,7 +115,7 @@ def my_model_fn(
h2 = tf.layers.Dense(10, activation=tf.nn.relu)(h1) h2 = tf.layers.Dense(10, activation=tf.nn.relu)(h1)
# Output 'logits' layer is three number = probability distribution # Output 'logits' layer is three number = probability distribution
# between Iris Sentosa, Versicolor, and Viginica # between Iris Setosa, Versicolor, and Viginica
logits = tf.layers.Dense(3)(h2) logits = tf.layers.Dense(3)(h2)
# class_ids will be the model prediction for the class (Iris flower type) # class_ids will be the model prediction for the class (Iris flower type)
...@@ -206,14 +205,14 @@ predict_results = classifier.predict( ...@@ -206,14 +205,14 @@ predict_results = classifier.predict(
tf.logging.info("Prediction on test file") tf.logging.info("Prediction on test file")
for prediction in predict_results: for prediction in predict_results:
# Will print the predicted class, i.e: 0, 1, or 2 if the prediction # Will print the predicted class, i.e: 0, 1, or 2 if the prediction
# is Iris Sentosa, Vericolor, Virginica, respectively. # is Iris Setosa, Vericolor, Virginica, respectively.
tf.logging.info("...{}".format(prediction["class_ids"])) tf.logging.info("...{}".format(prediction["class_ids"]))
# Let create a dataset for prediction # Let create a dataset for prediction
# We've taken the first 3 examples in FILE_TEST # We've taken the first 3 examples in FILE_TEST
prediction_input = [[5.9, 3.0, 4.2, 1.5], # -> 1, Iris Versicolor prediction_input = [[5.9, 3.0, 4.2, 1.5], # -> 1, Iris Versicolor
[6.9, 3.1, 5.4, 2.1], # -> 2, Iris Virginica [6.9, 3.1, 5.4, 2.1], # -> 2, Iris Virginica
[5.1, 3.3, 1.7, 0.5]] # -> 0, Iris Sentosa [5.1, 3.3, 1.7, 0.5]] # -> 0, Iris Setosa
def new_input_fn(): def new_input_fn():
def decode(x): def decode(x):
...@@ -234,7 +233,7 @@ tf.logging.info("Predictions on memory") ...@@ -234,7 +233,7 @@ tf.logging.info("Predictions on memory")
for idx, prediction in enumerate(predict_results): for idx, prediction in enumerate(predict_results):
type = prediction["class_ids"] # Get the predicted class (index) type = prediction["class_ids"] # Get the predicted class (index)
if type == 0: if type == 0:
tf.logging.info("...I think: {}, is Iris Sentosa".format(prediction_input[idx])) tf.logging.info("...I think: {}, is Iris Setosa".format(prediction_input[idx]))
elif type == 1: elif type == 1:
tf.logging.info("...I think: {}, is Iris Versicolor".format(prediction_input[idx])) tf.logging.info("...I think: {}, is Iris Versicolor".format(prediction_input[idx]))
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment