Commit adc27172 authored by George K's avatar George K Committed by Toby Boyd
Browse files

readme source url and import optimization (#7084)

* restored missing function

* missing import

* missing imports

* updated tutorial link

* recovered _print_download_progress func

* change default train with float16 instead of float32 accuracy

* test disable func call

* redundant function call, currently data is pulled automatically in

* optimized imports

* optimized imports
parent 47a59023
...@@ -10,4 +10,4 @@ Code in this directory demonstrates how to use TensorFlow to train and evaluate ...@@ -10,4 +10,4 @@ Code in this directory demonstrates how to use TensorFlow to train and evaluate
Detailed instructions on how to get started available at: Detailed instructions on how to get started available at:
http://tensorflow.org/tutorials/deep_cnn/ https://www.tensorflow.org/tutorials/images/deep_cnn
...@@ -46,7 +46,7 @@ FLAGS = tf.app.flags.FLAGS ...@@ -46,7 +46,7 @@ FLAGS = tf.app.flags.FLAGS
# Basic model parameters. # Basic model parameters.
tf.app.flags.DEFINE_integer('batch_size', 128, tf.app.flags.DEFINE_integer('batch_size', 128,
"""Number of images to process in a batch.""") """Number of images to process in a batch.""")
tf.app.flags.DEFINE_boolean('use_fp16', False, tf.app.flags.DEFINE_boolean('use_fp16', True,
"""Train the model using fp16.""") """Train the model using fp16.""")
# Global constants describing the CIFAR-10 data set. # Global constants describing the CIFAR-10 data set.
...@@ -146,7 +146,6 @@ def distorted_inputs(): ...@@ -146,7 +146,6 @@ def distorted_inputs():
def inputs(eval_data): def inputs(eval_data):
"""Construct input for CIFAR evaluation using the Reader ops. """Construct input for CIFAR evaluation using the Reader ops.
Args: Args:
eval_data: bool, indicating if one should use the train or eval data set. eval_data: bool, indicating if one should use the train or eval data set.
...@@ -154,8 +153,7 @@ def inputs(eval_data): ...@@ -154,8 +153,7 @@ def inputs(eval_data):
images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size. images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
labels: Labels. 1D tensor of [batch_size] size. labels: Labels. 1D tensor of [batch_size] size.
""" """
images, labels = cifar10_input.inputs(eval_data=eval_data, images, labels = cifar10_input.inputs(eval_data=eval_data, batch_size=FLAGS.batch_size)
batch_size=FLAGS.batch_size)
if FLAGS.use_fp16: if FLAGS.use_fp16:
images = tf.cast(images, tf.float16) images = tf.cast(images, tf.float16)
labels = tf.cast(labels, tf.float16) labels = tf.cast(labels, tf.float16)
......
...@@ -39,14 +39,15 @@ from __future__ import absolute_import ...@@ -39,14 +39,15 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from datetime import datetime
import os.path import os.path
import re import re
import time import time
from datetime import datetime
import numpy as np import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf import tensorflow as tf
from six.moves import xrange # pylint: disable=redefined-builtin
import cifar10 import cifar10
FLAGS = tf.app.flags.FLAGS FLAGS = tf.app.flags.FLAGS
...@@ -266,7 +267,6 @@ def train(): ...@@ -266,7 +267,6 @@ def train():
def main(argv=None): # pylint: disable=unused-argument def main(argv=None): # pylint: disable=unused-argument
cifar10.maybe_download_and_extract()
if tf.gfile.Exists(FLAGS.train_dir): if tf.gfile.Exists(FLAGS.train_dir):
tf.gfile.DeleteRecursively(FLAGS.train_dir) tf.gfile.DeleteRecursively(FLAGS.train_dir)
tf.gfile.MakeDirs(FLAGS.train_dir) tf.gfile.MakeDirs(FLAGS.train_dir)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment