Commit adc27172 authored by George K's avatar George K Committed by Toby Boyd
Browse files

readme source url and import optimization (#7084)

* restored missing function

* missing import

* missing imports

* updated tutorial link

* recovered _print_download_progress func

* change default train with float16 instead of float32 accuracy

* test disable func call

* redundant function call, currently data is pulled automatically in

* optimized imports

* optimized imports
parent 47a59023
......@@ -10,4 +10,4 @@ Code in this directory demonstrates how to use TensorFlow to train and evaluate
Detailed instructions on how to get started available at:
http://tensorflow.org/tutorials/deep_cnn/
https://www.tensorflow.org/tutorials/images/deep_cnn
......@@ -46,7 +46,7 @@ FLAGS = tf.app.flags.FLAGS
# Basic model parameters.
tf.app.flags.DEFINE_integer('batch_size', 128,
"""Number of images to process in a batch.""")
tf.app.flags.DEFINE_boolean('use_fp16', False,
tf.app.flags.DEFINE_boolean('use_fp16', True,
"""Train the model using fp16.""")
# Global constants describing the CIFAR-10 data set.
......@@ -146,7 +146,6 @@ def distorted_inputs():
def inputs(eval_data):
"""Construct input for CIFAR evaluation using the Reader ops.
Args:
eval_data: bool, indicating if one should use the train or eval data set.
......@@ -154,8 +153,7 @@ def inputs(eval_data):
images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
labels: Labels. 1D tensor of [batch_size] size.
"""
images, labels = cifar10_input.inputs(eval_data=eval_data,
batch_size=FLAGS.batch_size)
images, labels = cifar10_input.inputs(eval_data=eval_data, batch_size=FLAGS.batch_size)
if FLAGS.use_fp16:
images = tf.cast(images, tf.float16)
labels = tf.cast(labels, tf.float16)
......
......@@ -39,14 +39,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from datetime import datetime
import os.path
import re
import time
from datetime import datetime
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from six.moves import xrange # pylint: disable=redefined-builtin
import cifar10
FLAGS = tf.app.flags.FLAGS
......@@ -266,7 +267,6 @@ def train():
def main(argv=None): # pylint: disable=unused-argument
cifar10.maybe_download_and_extract()
if tf.gfile.Exists(FLAGS.train_dir):
tf.gfile.DeleteRecursively(FLAGS.train_dir)
tf.gfile.MakeDirs(FLAGS.train_dir)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment