Commit 6c875697 authored by ispirmustafa, committed by GitHub

Merge pull request #727 from ispirmustafa/master

Convert resnet model to use monitored_session
parents a533325c 931c70a1
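
For orientation, a rough before/after skeleton of the conversion this PR applies (TF 1.x). The checkpoint directory and step count below are placeholders rather than the model's flags, and the Supervisor half is left as a commented-out reference, not code from this repository.

import tensorflow as tf

train_op = tf.assign_add(
    tf.contrib.framework.get_or_create_global_step(), 1)  # stand-in train op

# Before: tf.train.Supervisor owned the session, checkpoints and summaries.
# sv = tf.train.Supervisor(logdir='/tmp/resnet_demo', save_model_secs=300)
# with sv.managed_session() as sess:
#   while not sv.should_stop():
#     sess.run(train_op)

# After: MonitoredTrainingSession provides the same services through hooks.
with tf.train.MonitoredTrainingSession(
    checkpoint_dir='/tmp/resnet_demo',
    hooks=[tf.train.StopAtStepHook(last_step=10)]) as mon_sess:
  while not mon_sess.should_stop():
    mon_sess.run(train_op)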
@@ -73,7 +73,7 @@ def build_input(dataset, data_path, batch_size, mode):
     # image = tf.image.random_brightness(image, max_delta=63. / 255.)
     # image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
     # image = tf.image.random_contrast(image, lower=0.2, upper=1.8)
-    image = tf.image.per_image_whitening(image)
+    image = tf.image.per_image_standardization(image)
     example_queue = tf.RandomShuffleQueue(
         capacity=16 * batch_size,
...
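
tf.image.per_image_whitening is the pre-1.0 name of tf.image.per_image_standardization; only the name changes, not the math. A rough NumPy sketch of what the op computes, based on its documented behaviour rather than the library source:

import numpy as np

def standardize_image(image):
  # Subtract the per-image mean and divide by the adjusted standard
  # deviation, clamped so a constant image does not divide by zero.
  image = image.astype(np.float64)
  mean = image.mean()
  adjusted_stddev = max(image.std(), 1.0 / np.sqrt(image.size))
  return (image - mean) / adjusted_stddev

# Example: a CIFAR-style 32x32x3 image with values in [0, 255].
img = np.random.randint(0, 256, size=(32, 32, 3))
print(standardize_image(img).mean())  # approximately 0.0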
@@ -15,8 +15,8 @@
 """ResNet Train/Eval module.
 """
+import sys
 import time
-import sys
 import cifar_input
 import numpy as np
@@ -26,8 +26,10 @@ import tensorflow as tf
 FLAGS = tf.app.flags.FLAGS
 tf.app.flags.DEFINE_string('dataset', 'cifar10', 'cifar10 or cifar100.')
 tf.app.flags.DEFINE_string('mode', 'train', 'train or eval.')
-tf.app.flags.DEFINE_string('train_data_path', '', 'Filepattern for training data.')
-tf.app.flags.DEFINE_string('eval_data_path', '', 'Filepattern for eval data')
+tf.app.flags.DEFINE_string('train_data_path', '',
+                           'Filepattern for training data.')
+tf.app.flags.DEFINE_string('eval_data_path', '',
+                           'Filepattern for eval data')
 tf.app.flags.DEFINE_integer('image_size', 32, 'Image side length.')
 tf.app.flags.DEFINE_string('train_dir', '',
                            'Directory to keep training outputs.')
@@ -50,50 +52,65 @@ def train(hps):
       FLAGS.dataset, FLAGS.train_data_path, hps.batch_size, FLAGS.mode)
   model = resnet_model.ResNet(hps, images, labels, FLAGS.mode)
   model.build_graph()
-  summary_writer = tf.train.SummaryWriter(FLAGS.train_dir)
-
-  sv = tf.train.Supervisor(logdir=FLAGS.log_root,
-                           is_chief=True,
-                           summary_op=None,
-                           save_summaries_secs=60,
-                           save_model_secs=300,
-                           global_step=model.global_step)
-  sess = sv.prepare_or_wait_for_session(
-      config=tf.ConfigProto(allow_soft_placement=True))
-
-  step = 0
-  lrn_rate = 0.1
-
-  while not sv.should_stop():
-    (_, summaries, loss, predictions, truth, train_step) = sess.run(
-        [model.train_op, model.summaries, model.cost, model.predictions,
-         model.labels, model.global_step],
-        feed_dict={model.lrn_rate: lrn_rate})
-
-    if train_step < 40000:
-      lrn_rate = 0.1
-    elif train_step < 60000:
-      lrn_rate = 0.01
-    elif train_step < 80000:
-      lrn_rate = 0.001
-    else:
-      lrn_rate = 0.0001
-
-    truth = np.argmax(truth, axis=1)
-    predictions = np.argmax(predictions, axis=1)
-    precision = np.mean(truth == predictions)
-
-    step += 1
-    if step % 100 == 0:
-      precision_summ = tf.Summary()
-      precision_summ.value.add(
-          tag='Precision', simple_value=precision)
-      summary_writer.add_summary(precision_summ, train_step)
-      summary_writer.add_summary(summaries, train_step)
-      tf.logging.info('loss: %.3f, precision: %.3f\n' % (loss, precision))
-      summary_writer.flush()
-
-  sv.Stop()
+
+  param_stats = tf.contrib.tfprof.model_analyzer.print_model_analysis(
+      tf.get_default_graph(),
+      tfprof_options=tf.contrib.tfprof.model_analyzer.
+          TRAINABLE_VARS_PARAMS_STAT_OPTIONS)
+  sys.stdout.write('total_params: %d\n' % param_stats.total_parameters)
+
+  tf.contrib.tfprof.model_analyzer.print_model_analysis(
+      tf.get_default_graph(),
+      tfprof_options=tf.contrib.tfprof.model_analyzer.FLOAT_OPS_OPTIONS)
+
+  truth = tf.argmax(model.labels, axis=1)
+  predictions = tf.argmax(model.predictions, axis=1)
+  precision = tf.reduce_mean(tf.to_float(tf.equal(predictions, truth)))
+
+  summary_hook = tf.train.SummarySaverHook(
+      save_steps=100,
+      output_dir=FLAGS.train_dir,
+      summary_op=[model.summaries,
+                  tf.summary.scalar('Precision', precision)])
+
+  logging_hook = tf.train.LoggingTensorHook(
+      tensors={'step': model.global_step,
+               'loss': model.cost,
+               'precision': precision},
+      every_n_iter=100)
+
+  class _LearningRateSetterHook(tf.train.SessionRunHook):
+    """Sets learning_rate based on global step."""
+
+    def begin(self):
+      self._lrn_rate = 0.1
+
+    def before_run(self, run_context):
+      return tf.train.SessionRunArgs(
+          model.global_step,  # Asks for global step value.
+          feed_dict={model.lrn_rate: self._lrn_rate})  # Sets learning rate
+
+    def after_run(self, run_context, run_values):
+      train_step = run_values.results
+      if train_step < 40000:
+        self._lrn_rate = 0.1
+      elif train_step < 60000:
+        self._lrn_rate = 0.01
+      elif train_step < 80000:
+        self._lrn_rate = 0.001
+      else:
+        self._lrn_rate = 0.0001
+
+  with tf.train.MonitoredTrainingSession(
+      checkpoint_dir=FLAGS.log_root,
+      hooks=[logging_hook, _LearningRateSetterHook()],
+      chief_only_hooks=[summary_hook],
+      # Since we provide a SummarySaverHook, we need to disable default
+      # SummarySaverHook. To do that we set save_summaries_steps to 0.
+      save_summaries_steps=0,
+      config=tf.ConfigProto(allow_soft_placement=True)) as mon_sess:
+    while not mon_sess.should_stop():
+      mon_sess.run(model.train_op)
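
The learning-rate schedule now lives in a SessionRunHook instead of the old in-loop feed_dict. Below is a minimal, self-contained sketch of that hook life cycle (begin / before_run / after_run); the tensor names and the step boundary are illustrative, not the model's.

import tensorflow as tf

global_step = tf.contrib.framework.get_or_create_global_step()
lrn_rate = tf.placeholder(tf.float32, shape=[], name='lrn_rate')
train_op = tf.assign_add(global_step, 1)  # stand-in for a real train op


class _StepDependentFeedHook(tf.train.SessionRunHook):
  """Feeds a value that depends on the current global step."""

  def begin(self):
    self._value = 0.1

  def before_run(self, run_context):
    # Ask for the global step and feed the current value alongside the run.
    return tf.train.SessionRunArgs(
        global_step, feed_dict={lrn_rate: self._value})

  def after_run(self, run_context, run_values):
    step = run_values.results
    # Decay the fed value once the step passes a made-up boundary.
    self._value = 0.1 if step < 5 else 0.01


with tf.train.MonitoredTrainingSession(
    hooks=[_StepDependentFeedHook(),
           tf.train.StopAtStepHook(last_step=10)]) as mon_sess:
  while not mon_sess.should_stop():
    mon_sess.run(train_op)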
@@ -103,7 +120,7 @@ def evaluate(hps):
   model = resnet_model.ResNet(hps, images, labels, FLAGS.mode)
   model.build_graph()
   saver = tf.train.Saver()
-  summary_writer = tf.train.SummaryWriter(FLAGS.eval_dir)
+  summary_writer = tf.summary.FileWriter(FLAGS.eval_dir)
 
   sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
   tf.train.start_queue_runners(sess)
...
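
tf.train.SummaryWriter is likewise the pre-1.0 name; the eval path now writes through tf.summary.FileWriter. A small sketch of the renamed writer, using a placeholder log directory rather than FLAGS.eval_dir:

import tensorflow as tf

summary_writer = tf.summary.FileWriter('/tmp/resnet_eval_demo')
precision_summ = tf.Summary()
precision_summ.value.add(tag='Precision', simple_value=0.9)
summary_writer.add_summary(precision_summ, global_step=100)
summary_writer.flush()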
@@ -55,7 +55,7 @@ class ResNet(object):
   def build_graph(self):
     """Build a whole graph for the model."""
-    self.global_step = tf.Variable(0, name='global_step', trainable=False)
+    self.global_step = tf.contrib.framework.get_or_create_global_step()
     self._build_model()
     if self.mode == 'train':
       self._build_train_op()
...
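
get_or_create_global_step() registers the step variable in the GLOBAL_STEP collection (and reuses it if one already exists), which is how MonitoredTrainingSession and hooks such as StopAtStepHook locate it. A short sketch of that behaviour; nothing here is specific to the resnet code:

import tensorflow as tf

global_step = tf.contrib.framework.get_or_create_global_step()

# A second call returns the variable already registered in the collection
# instead of creating a duplicate.
assert tf.contrib.framework.get_or_create_global_step() is global_step

# Hooks and MonitoredTrainingSession look the step up the same way.
print(tf.train.get_global_step())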