Unverified Commit ac6ab360 authored by Mark Daoust's avatar Mark Daoust Committed by GitHub
Browse files

Merge pull request #3260 from lc0/dataset

Fix a typo. Python function name
parents 3f78f4cf f8e854b5
...@@ -38,7 +38,7 @@ URL_TRAIN = "http://download.tensorflow.org/data/iris_training.csv" ...@@ -38,7 +38,7 @@ URL_TRAIN = "http://download.tensorflow.org/data/iris_training.csv"
URL_TEST = "http://download.tensorflow.org/data/iris_test.csv" URL_TEST = "http://download.tensorflow.org/data/iris_test.csv"
def downloadDataset(url, file): def download_dataset(url, file):
if not os.path.exists(PATH_DATASET): if not os.path.exists(PATH_DATASET):
os.makedirs(PATH_DATASET) os.makedirs(PATH_DATASET)
if not os.path.exists(file): if not os.path.exists(file):
...@@ -46,8 +46,8 @@ def downloadDataset(url, file): ...@@ -46,8 +46,8 @@ def downloadDataset(url, file):
with open(file, "wb") as f: with open(file, "wb") as f:
f.write(data) f.write(data)
f.close() f.close()
downloadDataset(URL_TRAIN, FILE_TRAIN) download_dataset(URL_TRAIN, FILE_TRAIN)
downloadDataset(URL_TEST, FILE_TEST) download_dataset(URL_TEST, FILE_TEST)
tf.logging.set_verbosity(tf.logging.INFO) tf.logging.set_verbosity(tf.logging.INFO)
...@@ -97,7 +97,7 @@ classifier = tf.estimator.DNNClassifier( ...@@ -97,7 +97,7 @@ classifier = tf.estimator.DNNClassifier(
n_classes=3, n_classes=3,
model_dir=PATH) # Path to where checkpoints etc are stored model_dir=PATH) # Path to where checkpoints etc are stored
# Train our model, use the previously function my_input_fn # Train our model, use the previously defined function my_input_fn
# Input to training is a file with training example # Input to training is a file with training example
# Stop training after 8 iterations of train data (epochs) # Stop training after 8 iterations of train data (epochs)
classifier.train( classifier.train(
...@@ -127,6 +127,7 @@ prediction_input = [[5.9, 3.0, 4.2, 1.5], # -> 1, Iris Versicolor ...@@ -127,6 +127,7 @@ prediction_input = [[5.9, 3.0, 4.2, 1.5], # -> 1, Iris Versicolor
[6.9, 3.1, 5.4, 2.1], # -> 2, Iris Virginica [6.9, 3.1, 5.4, 2.1], # -> 2, Iris Virginica
[5.1, 3.3, 1.7, 0.5]] # -> 0, Iris Sentosa [5.1, 3.3, 1.7, 0.5]] # -> 0, Iris Sentosa
def new_input_fn(): def new_input_fn():
def decode(x): def decode(x):
x = tf.split(x, 4) # Need to split into our 4 features x = tf.split(x, 4) # Need to split into our 4 features
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment