Commit 4ec3452d authored by Yaroslav Bulatov's avatar Yaroslav Bulatov
Browse files

another xrange change + change to concat_v2

parent 10340bf5
......@@ -18,15 +18,6 @@
import tensorflow as tf
# backward compatible concat (arg order changed in head)
import inspect
def concat(values, axis):
if 'axis' in inspect.signature(tf.concat).parameters.keys():
return tf.concat(values=values, axis=axis)
else:
assert 'concat_dim' in inspect.signature(tf.concat).parameters.keys()
return tf.concat(concat_dim=axis, values=values)
def build_input(dataset, data_path, batch_size, mode):
"""Build CIFAR image and labels.
......@@ -109,7 +100,7 @@ def build_input(dataset, data_path, batch_size, mode):
labels = tf.reshape(labels, [batch_size, 1])
indices = tf.reshape(tf.range(0, batch_size, 1), [batch_size, 1])
labels = tf.sparse_to_dense(
tf.concat(values=[indices, labels], axis=1),
tf.concat_v2(values=[indices, labels], axis=1),
[batch_size, num_classes], 1.0, 0.0)
assert len(images.get_shape()) == 4
......
......@@ -16,6 +16,7 @@
"""ResNet Train/Eval module.
"""
import time
import six
import sys
import cifar_input
......@@ -140,7 +141,7 @@ def evaluate(hps):
saver.restore(sess, ckpt_state.model_checkpoint_path)
total_prediction, correct_prediction = 0, 0
for _ in xrange(FLAGS.eval_batch_count):
for _ in six.moves.range(FLAGS.eval_batch_count):
(summaries, loss, predictions, truth, train_step) = sess.run(
[model.summaries, model.cost, model.predictions,
model.labels, model.global_step])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment