Unverified Commit 0f5803bd authored by Mark Daoust's avatar Mark Daoust Committed by GitHub
Browse files

Merge pull request #3028 from tensorflow/mhyttsten-patch-2

Update blog_custom_estimators.py
parents d19587e4 688143ed
......@@ -14,8 +14,7 @@
# ==============================================================================
# This is the complete code for the following blogpost:
# https://developers.googleblog.com/2017/09/introducing-tensorflow-datasets.html
# (https://goo.gl/Ujm2Ep)
# https://developers.googleblog.com/2017/12/creating-custom-estimators-in-tensorflow.html
import tensorflow as tf
import os
......@@ -116,7 +115,7 @@ def my_model_fn(
h2 = tf.layers.Dense(10, activation=tf.nn.relu)(h1)
# Output 'logits' layer is three number = probability distribution
# between Iris Sentosa, Versicolor, and Viginica
# between Iris Setosa, Versicolor, and Viginica
logits = tf.layers.Dense(3)(h2)
# class_ids will be the model prediction for the class (Iris flower type)
......@@ -206,14 +205,14 @@ predict_results = classifier.predict(
tf.logging.info("Prediction on test file")
for prediction in predict_results:
# Will print the predicted class, i.e: 0, 1, or 2 if the prediction
# is Iris Sentosa, Vericolor, Virginica, respectively.
# is Iris Setosa, Vericolor, Virginica, respectively.
tf.logging.info("...{}".format(prediction["class_ids"]))
# Let create a dataset for prediction
# We've taken the first 3 examples in FILE_TEST
prediction_input = [[5.9, 3.0, 4.2, 1.5], # -> 1, Iris Versicolor
[6.9, 3.1, 5.4, 2.1], # -> 2, Iris Virginica
[5.1, 3.3, 1.7, 0.5]] # -> 0, Iris Sentosa
[5.1, 3.3, 1.7, 0.5]] # -> 0, Iris Setosa
def new_input_fn():
def decode(x):
......@@ -234,7 +233,7 @@ tf.logging.info("Predictions on memory")
for idx, prediction in enumerate(predict_results):
type = prediction["class_ids"] # Get the predicted class (index)
if type == 0:
tf.logging.info("...I think: {}, is Iris Sentosa".format(prediction_input[idx]))
tf.logging.info("...I think: {}, is Iris Setosa".format(prediction_input[idx]))
elif type == 1:
tf.logging.info("...I think: {}, is Iris Versicolor".format(prediction_input[idx]))
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment