# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

"""Builds the CIFAR-10 network.

Summary of available functions:

 # Compute input images and labels for training. If you would like to run
 # evaluations, use inputs() instead.
 inputs, labels = distorted_inputs()

 # Compute inference on the model inputs to make a prediction.
 predictions = inference(inputs)

 # Compute the total loss of the prediction with respect to the labels.
 loss = loss(predictions, labels)

 # Create a graph to run one step of training with respect to the loss.
 train_op = train(loss, global_step)
"""

import argparse
import os
import re
import tarfile
# Import the submodule explicitly: a bare ``import urllib`` does not guarantee
# that ``urllib.request`` is loaded. This still binds the name ``urllib``.
import urllib.request

import tensorflow as tf

import cifar10_input


def _str2bool(value):
    """Convert a command-line string to a bool for argparse.

    ``type=bool`` is a trap with argparse: ``bool('False')`` is ``True``, so
    any non-empty value would enable the flag. This converter understands the
    usual spellings instead.

    Args:
      value: the raw value (a string from the command line, or a bool).

    Returns:
      True or False.

    Raises:
      argparse.ArgumentTypeError: if the value is not a recognised boolean
        spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string maps to False, matching what bool('') returned before.
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'Expected a boolean value, got %r' % (value,))


parser = argparse.ArgumentParser()

# Basic model parameters.
parser.add_argument('--batch_size', type=int, default=512,
                    help='Number of images to process in a batch.')

parser.add_argument('--data_dir', type=str, default='/tmp/cifar10_data',
                    help='Path to the CIFAR-10 data directory.')

parser.add_argument('--use_fp16', type=_str2bool, default=False,
                    help='Train the model using fp16.')

# The flags are parsed at import time. parse_known_args() is used so that a
# script which imports this module and defines additional flags of its own is
# not rejected here with "unrecognized arguments".
FLAGS = parser.parse_known_args()[0]

# Global constants describing the CIFAR-10 data set.
IMAGE_SIZE = cifar10_input.IMAGE_SIZE
NUM_CLASSES = cifar10_input.NUM_CLASSES
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_EVAL

# Constants describing the training process.
MOVING_AVERAGE_DECAY = 0.9999     # The decay to use for the moving average.
NUM_EPOCHS_PER_DECAY = 350.0      # Epochs after which learning rate decays.
LEARNING_RATE_DECAY_FACTOR = 0.1  # Learning rate decay factor.
INITIAL_LEARNING_RATE = 0.1       # Initial learning rate.

# If a model is trained with multiple GPUs, prefix all Op names with tower_name
# to differentiate the operations. Note that this prefix is removed from the
# names of the summaries when visualizing a model.
TOWER_NAME = 'tower'

DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'


def _activation_summary(x_input):
    """Helper to create summaries for activations.

    Creates a summary that provides a histogram of activations.
    Creates a summary that measures the sparsity of activations.

    Args:
      x_input: Tensor

    Returns:
      nothing
    """
    # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training
    # session. This helps the clarity of presentation on tensorboard.
    tensor_name = re.sub('%s_[0-9]*/' % TOWER_NAME, '', x_input.op.name)
    tf.summary.histogram(tensor_name + '/activations', x_input)
    tf.summary.scalar(tensor_name + '/sparsity',
                      tf.nn.zero_fraction(x_input))


def _variable_on_cpu(name, shape, initializer):
    """Helper to create a Variable stored on CPU memory.

    The variable dtype is float16 when --use_fp16 is set, float32 otherwise.

    Args:
      name: name of the variable
      shape: list of ints
      initializer: initializer for Variable

    Returns:
      Variable Tensor
    """
    with tf.device('/cpu:0'):
        dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
        var = tf.get_variable(
            name, shape, initializer=initializer, dtype=dtype)
    return var


def _variable_with_weight_decay(name, shape, stddev, l2loss_wd):
    """Helper to create an initialized Variable with weight decay.

    Note that the Variable is initialized with a truncated normal distribution.
    A weight decay is added only if one is specified.

    Args:
      name: name of the variable
      shape: list of ints
      stddev: standard deviation of a truncated Gaussian
      l2loss_wd: add L2Loss weight decay multiplied by this float. If None,
          weight decay is not added for this Variable.

    Returns:
      Variable Tensor
    """
    dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
    var = _variable_on_cpu(
        name,
        shape,
        tf.truncated_normal_initializer(stddev=stddev, dtype=dtype))
    if l2loss_wd is not None:
        weight_decay = tf.multiply(
            tf.nn.l2_loss(var), l2loss_wd, name='weight_loss')
        # The decay term is registered in the 'losses' collection.
        tf.add_to_collection('losses', weight_decay)
    return var


def distorted_inputs():
    """Construct distorted input for CIFAR training using the Reader ops.

    Returns:
      images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
      labels: Labels. 1D tensor of [batch_size] size.

    Raises:
      ValueError: If no data_dir
    """
    if not FLAGS.data_dir:
        raise ValueError('Please supply a data_dir')
    # NOTE(review): the --data_dir flag is overwritten with the current
    # directory here, so the value given on the command line is only used for
    # the emptiness check above. Kept as-is; confirm whether this is intended.
    FLAGS.data_dir = './'
    print(FLAGS.data_dir)
    data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
    images, labels = cifar10_input.distorted_inputs(
        data_dir=data_dir, batch_size=FLAGS.batch_size)
    if FLAGS.use_fp16:
        images = tf.cast(images, tf.float16)
        labels = tf.cast(labels, tf.float16)
    return images, labels


def inputs(eval_data):
    """Construct input for CIFAR evaluation using the Reader ops.

    Args:
      eval_data: bool, indicating if one should use the train or eval data set.
        If None, the distorted training inputs are built instead.

    Returns:
      images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
      labels: Labels. 1D tensor of [batch_size] size.

    Raises:
      ValueError: If no data_dir
    """
    # NOTE(review): the --data_dir flag is overwritten with the current
    # directory before it is validated, so the check below cannot fail as
    # written. Kept as-is; confirm whether the override is intended.
    FLAGS.data_dir = './'
    if not FLAGS.data_dir:
        raise ValueError('Please supply a data_dir')
    data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
    if eval_data is None:
        images, labels = cifar10_input.distorted_inputs(
            data_dir=data_dir, batch_size=FLAGS.batch_size)
    else:
        images, labels = cifar10_input.inputs(
            eval_data=eval_data,
            data_dir=data_dir,
            batch_size=FLAGS.batch_size)
    if FLAGS.use_fp16:
        images = tf.cast(images, tf.float16)
        labels = tf.cast(labels, tf.float16)
    return images, labels


def maybe_download_and_extract():
    """Download and extract the tarball from Alex's website.

    The archive named by DATA_URL is downloaded into the data directory unless
    a file with that name already exists there, and it is extracted unless the
    'cifar-10-batches-bin' directory already exists.
    """
    # NOTE(review): the --data_dir flag is overwritten with the current
    # directory here. Kept as-is; confirm whether this is intended.
    FLAGS.data_dir = './'
    dest_directory = FLAGS.data_dir
    print(dest_directory)
    if not os.path.exists(dest_directory):
        os.makedirs(dest_directory)
    filename = DATA_URL.split('/')[-1]
    filepath = os.path.join(dest_directory, filename)
    if not os.path.exists(filepath):
        def _progress(count, block_size, total_size):
            """Report hook for urlretrieve: rewrite one status line in place."""
            received = count * block_size
            if total_size > 0:
                # The last block can overshoot the real size; cap at 100%.
                percent = min(
                    float(received) / float(total_size) * 100.0, 100.0)
                print('\r>> Downloading %s %.1f%%' % (filename, percent),
                      end='', flush=True)
            else:
                # Size unknown (or reported as zero): show bytes received
                # rather than dividing by it.
                print('\r>> Downloading %s %d bytes' % (filename, received),
                      end='', flush=True)

        filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
        # Terminate the in-place progress line.
        print()
        statinfo = os.stat(filepath)
        print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
    extracted_dir_path = os.path.join(dest_directory, 'cifar-10-batches-bin')
    if not os.path.exists(extracted_dir_path):
        # NOTE(review): the archive is fetched over plain HTTP and extracted
        # without member validation; consider verifying a checksum or using a
        # tarfile extraction filter where the Python version supports it.
        with tarfile.open(filepath, 'r:gz') as archive:
            archive.extractall(dest_directory)