Commit 10340bf5 authored by Yaroslav Bulatov

Changes for TF 1.0 compatibility

parent d66941ac
@@ -18,6 +18,14 @@
 import tensorflow as tf
 
+# backward compatible concat (arg order changed in head)
+import inspect
+def concat(values, axis):
+  if 'axis' in inspect.signature(tf.concat).parameters.keys():
+    return tf.concat(values=values, axis=axis)
+  else:
+    assert 'concat_dim' in inspect.signature(tf.concat).parameters.keys()
+    return tf.concat(concat_dim=axis, values=values)
 
 def build_input(dataset, data_path, batch_size, mode):
   """Build CIFAR image and labels.
@@ -101,7 +109,7 @@ def build_input(dataset, data_path, batch_size, mode):
   labels = tf.reshape(labels, [batch_size, 1])
   indices = tf.reshape(tf.range(0, batch_size, 1), [batch_size, 1])
   labels = tf.sparse_to_dense(
-      tf.concat(1, [indices, labels]),
+      tf.concat(values=[indices, labels], axis=1),
       [batch_size, num_classes], 1.0, 0.0)
 
   assert len(images.get_shape()) == 4
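For reference, a self-contained sketch (toy sizes chosen here, not taken from the commit) of what the rewritten call computes: each row index is paired with its label value, and tf.sparse_to_dense scatters 1.0 into those positions to produce one-hot label vectors.

import tensorflow as tf

batch_size, num_classes = 3, 5
labels = tf.constant([[2], [0], [4]])                      # [batch_size, 1]
indices = tf.reshape(tf.range(0, batch_size, 1), [batch_size, 1])
sparse_idx = tf.concat(values=[indices, labels], axis=1)   # [[0, 2], [1, 0], [2, 4]]
one_hot = tf.sparse_to_dense(sparse_idx, [batch_size, num_classes], 1.0, 0.0)
# evaluates to [[0, 0, 1, 0, 0], [1, 0, 0, 0, 0], [0, 0, 0, 0, 1]]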
@@ -24,6 +24,7 @@ from collections import namedtuple
 import numpy as np
 import tensorflow as tf
+import six
 
 from tensorflow.python.training import moving_averages
@@ -89,21 +90,21 @@ class ResNet(object):
     with tf.variable_scope('unit_1_0'):
       x = res_func(x, filters[0], filters[1], self._stride_arr(strides[0]),
                    activate_before_residual[0])
-    for i in xrange(1, self.hps.num_residual_units):
+    for i in six.moves.range(1, self.hps.num_residual_units):
       with tf.variable_scope('unit_1_%d' % i):
         x = res_func(x, filters[1], filters[1], self._stride_arr(1), False)
 
     with tf.variable_scope('unit_2_0'):
       x = res_func(x, filters[1], filters[2], self._stride_arr(strides[1]),
                    activate_before_residual[1])
-    for i in xrange(1, self.hps.num_residual_units):
+    for i in six.moves.range(1, self.hps.num_residual_units):
       with tf.variable_scope('unit_2_%d' % i):
         x = res_func(x, filters[2], filters[2], self._stride_arr(1), False)
 
     with tf.variable_scope('unit_3_0'):
       x = res_func(x, filters[2], filters[3], self._stride_arr(strides[2]),
                    activate_before_residual[2])
-    for i in xrange(1, self.hps.num_residual_units):
+    for i in six.moves.range(1, self.hps.num_residual_units):
       with tf.variable_scope('unit_3_%d' % i):
         x = res_func(x, filters[3], filters[3], self._stride_arr(1), False)
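An aside on the substitution above: six.moves.range maps to xrange on Python 2 and to the built-in range on Python 3, where xrange no longer exists. A tiny check, assuming six is installed:

import six

for i in six.moves.range(1, 4):
  print('unit_1_%d' % i)   # unit_1_1, unit_1_2, unit_1_3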
@@ -118,7 +119,7 @@ class ResNet(object):
     with tf.variable_scope('costs'):
       xent = tf.nn.softmax_cross_entropy_with_logits(
-          logits, self.labels)
+          logits=logits, labels=self.labels)
       self.cost = tf.reduce_mean(xent, name='xent')
       self.cost += self._decay()
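In TF 1.0, tf.nn.softmax_cross_entropy_with_logits is meant to be called with named labels= and logits= arguments rather than positionally, which is what the rewrite above does. A standalone sketch with made-up values:

import tensorflow as tf

logits = tf.constant([[2.0, 0.5, 0.1]])
labels = tf.constant([[1.0, 0.0, 0.0]])   # one-hot target
xent = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labels)
loss = tf.reduce_mean(xent, name='xent')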
@@ -266,7 +267,7 @@ class ResNet(object):
         costs.append(tf.nn.l2_loss(var))
         # tf.histogram_summary(var.op.name, var)
 
-    return tf.mul(self.hps.weight_decay_rate, tf.add_n(costs))
+    return tf.multiply(self.hps.weight_decay_rate, tf.add_n(costs))
 
   def _conv(self, name, x, filter_size, in_filters, out_filters, strides):
     """Convolution."""
@@ -280,7 +281,7 @@ class ResNet(object):
   def _relu(self, x, leakiness=0.0):
     """Relu, with optional leaky support."""
-    return tf.select(tf.less(x, 0.0), leakiness * x, x, name='leaky_relu')
+    return tf.where(tf.less(x, 0.0), leakiness * x, x, name='leaky_relu')
 
   def _fully_connected(self, x, out_dim):
     """FullyConnected layer for final output."""
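To illustrate the two renames used in this file (tf.mul became tf.multiply and tf.select became tf.where in TF 1.0), a small standalone graph reproducing the leaky-ReLU expression from _relu on made-up inputs:

import tensorflow as tf

x = tf.constant([-2.0, -0.5, 0.0, 3.0])
leakiness = 0.1
leaky = tf.where(tf.less(x, 0.0), tf.multiply(leakiness, x), x, name='leaky_relu')

with tf.Session() as sess:
  print(sess.run(leaky))   # [-0.2  -0.05  0.    3.  ]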