Commit 931c70a1 authored by Mustafa Ispir's avatar Mustafa Ispir Committed by Mustafa Ispir
Browse files

Convert resnet model to use monitored_session

parent a533325c
......@@ -73,7 +73,7 @@ def build_input(dataset, data_path, batch_size, mode):
# image = tf.image.random_brightness(image, max_delta=63. / 255.)
# image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
# image = tf.image.random_contrast(image, lower=0.2, upper=1.8)
image = tf.image.per_image_whitening(image)
image = tf.image.per_image_standardization(image)
example_queue = tf.RandomShuffleQueue(
capacity=16 * batch_size,
......
......@@ -15,8 +15,8 @@
"""ResNet Train/Eval module.
"""
import sys
import time
import sys
import cifar_input
import numpy as np
......@@ -26,8 +26,10 @@ import tensorflow as tf
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('dataset', 'cifar10', 'cifar10 or cifar100.')
tf.app.flags.DEFINE_string('mode', 'train', 'train or eval.')
tf.app.flags.DEFINE_string('train_data_path', '', 'Filepattern for training data.')
tf.app.flags.DEFINE_string('eval_data_path', '', 'Filepattern for eval data')
tf.app.flags.DEFINE_string('train_data_path', '',
'Filepattern for training data.')
tf.app.flags.DEFINE_string('eval_data_path', '',
'Filepattern for eval data')
tf.app.flags.DEFINE_integer('image_size', 32, 'Image side length.')
tf.app.flags.DEFINE_string('train_dir', '',
'Directory to keep training outputs.')
......@@ -50,50 +52,65 @@ def train(hps):
FLAGS.dataset, FLAGS.train_data_path, hps.batch_size, FLAGS.mode)
model = resnet_model.ResNet(hps, images, labels, FLAGS.mode)
model.build_graph()
summary_writer = tf.train.SummaryWriter(FLAGS.train_dir)
sv = tf.train.Supervisor(logdir=FLAGS.log_root,
is_chief=True,
summary_op=None,
save_summaries_secs=60,
save_model_secs=300,
global_step=model.global_step)
sess = sv.prepare_or_wait_for_session(
config=tf.ConfigProto(allow_soft_placement=True))
step = 0
lrn_rate = 0.1
while not sv.should_stop():
(_, summaries, loss, predictions, truth, train_step) = sess.run(
[model.train_op, model.summaries, model.cost, model.predictions,
model.labels, model.global_step],
feed_dict={model.lrn_rate: lrn_rate})
if train_step < 40000:
lrn_rate = 0.1
elif train_step < 60000:
lrn_rate = 0.01
elif train_step < 80000:
lrn_rate = 0.001
else:
lrn_rate = 0.0001
truth = np.argmax(truth, axis=1)
predictions = np.argmax(predictions, axis=1)
precision = np.mean(truth == predictions)
step += 1
if step % 100 == 0:
precision_summ = tf.Summary()
precision_summ.value.add(
tag='Precision', simple_value=precision)
summary_writer.add_summary(precision_summ, train_step)
summary_writer.add_summary(summaries, train_step)
tf.logging.info('loss: %.3f, precision: %.3f\n' % (loss, precision))
summary_writer.flush()
sv.Stop()
param_stats = tf.contrib.tfprof.model_analyzer.print_model_analysis(
tf.get_default_graph(),
tfprof_options=tf.contrib.tfprof.model_analyzer.
TRAINABLE_VARS_PARAMS_STAT_OPTIONS)
sys.stdout.write('total_params: %d\n' % param_stats.total_parameters)
tf.contrib.tfprof.model_analyzer.print_model_analysis(
tf.get_default_graph(),
tfprof_options=tf.contrib.tfprof.model_analyzer.FLOAT_OPS_OPTIONS)
truth = tf.argmax(model.labels, axis=1)
predictions = tf.argmax(model.predictions, axis=1)
precision = tf.reduce_mean(tf.to_float(tf.equal(predictions, truth)))
summary_hook = tf.train.SummarySaverHook(
save_steps=100,
output_dir=FLAGS.train_dir,
summary_op=[model.summaries,
tf.summary.scalar('Precision', precision)])
logging_hook = tf.train.LoggingTensorHook(
tensors={'step': model.global_step,
'loss': model.cost,
'precision': precision},
every_n_iter=100)
class _LearningRateSetterHook(tf.train.SessionRunHook):
"""Sets learning_rate based on global step."""
def begin(self):
self._lrn_rate = 0.1
def before_run(self, run_context):
return tf.train.SessionRunArgs(
model.global_step, # Asks for global step value.
feed_dict={model.lrn_rate: self._lrn_rate}) # Sets learning rate
def after_run(self, run_context, run_values):
train_step = run_values.results
if train_step < 40000:
self._lrn_rate = 0.1
elif train_step < 60000:
self._lrn_rate = 0.01
elif train_step < 80000:
self._lrn_rate = 0.001
else:
self._lrn_rate = 0.0001
with tf.train.MonitoredTrainingSession(
checkpoint_dir=FLAGS.log_root,
hooks=[logging_hook, _LearningRateSetterHook()],
chief_only_hooks=[summary_hook],
# Since we provide a SummarySaverHook, we need to disable default
# SummarySaverHook. To do that we set save_summaries_steps to 0.
save_summaries_steps=0,
config=tf.ConfigProto(allow_soft_placement=True)) as mon_sess:
while not mon_sess.should_stop():
mon_sess.run(model.train_op)
def evaluate(hps):
......@@ -103,7 +120,7 @@ def evaluate(hps):
model = resnet_model.ResNet(hps, images, labels, FLAGS.mode)
model.build_graph()
saver = tf.train.Saver()
summary_writer = tf.train.SummaryWriter(FLAGS.eval_dir)
summary_writer = tf.summary.FileWriter(FLAGS.eval_dir)
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
tf.train.start_queue_runners(sess)
......
......@@ -55,7 +55,7 @@ class ResNet(object):
def build_graph(self):
"""Build a whole graph for the model."""
self.global_step = tf.Variable(0, name='global_step', trainable=False)
self.global_step = tf.contrib.framework.get_or_create_global_step()
self._build_model()
if self.mode == 'train':
self._build_train_op()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment