Commit d34cf181 authored by Olivia's avatar Olivia
Browse files

sort of works...still a namespace problem in _train

parent 16454bdf
......@@ -40,21 +40,20 @@ import re
import sys
import tarfile
import argparse
from six.moves import urllib
import tensorflow as tf
import cifar10_input
FLAGS = tf.app.flags.FLAGS
parser = argparse.ArgumentParser()
# Basic model parameters.
tf.app.flags.DEFINE_integer('batch_size', 128,
"""Number of images to process in a batch.""")
tf.app.flags.DEFINE_string('data_dir', '/tmp/cifar10_data',
"""Path to the CIFAR-10 data directory.""")
tf.app.flags.DEFINE_boolean('use_fp16', False,
"""Train the model using fp16.""")
parser.add_argument('--batch_size', type=int, default=128, help='Number of images to process in a batch.')
parser.add_argument('--data_dir', type=str, default='/tmp/cifar10_data', help='Path to the CIFAR-10 data directory.')
parser.add_argument('--use_fp16', type=bool, default=False, help='Train the model using fp16.')
FLAGS = parser.parse_args()
# Global constants describing the CIFAR-10 data set.
IMAGE_SIZE = cifar10_input.IMAGE_SIZE
NUM_CLASSES = cifar10_input.NUM_CLASSES
......
......@@ -39,6 +39,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
from datetime import datetime
import os.path
import re
......@@ -49,17 +50,15 @@ from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
import cifar10
FLAGS = tf.app.flags.FLAGS
parser = argparse.ArgumentParser()
tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train',
"""Directory where to write event logs """
"""and checkpoint.""")
tf.app.flags.DEFINE_integer('max_steps', 1000000,
"""Number of batches to run.""")
tf.app.flags.DEFINE_integer('num_gpus', 1,
"""How many GPUs to use.""")
tf.app.flags.DEFINE_boolean('log_device_placement', False,
"""Whether to log device placement.""")
parser.add_argument('--train_dir', type=str, default='/tmp/cifar10_train', help='Directory where to write event logs and checkpoint.')
parser.add_argument('--max_steps', type=int, default=1000000, help='Number of batches to run.')
parser.add_argument('--num_gpus', type=int, default=1, help='How many GPUs to use.')
parser.add_argument('--log_device_placement', type=bool, default=False, help='Whether to log device placement.')
def tower_loss(scope, images, labels):
......@@ -274,4 +273,5 @@ def main(argv=None): # pylint: disable=unused-argument
if __name__ == '__main__':
FLAGS = parser.parse_args()
tf.app.run()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment