"examples/git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "709b4439bead016bef8f3af4b68cfc3c6429af33"
Commit 22036b6f authored by Wills Cui's avatar Wills Cui Committed by drpngx
Browse files

Update resnet to run with tf r0.12 API. (#833)

* Update resnet to run with tf r0.12 API.
1. tf.image.per_image_whitening -> tf.image.per_image_standardization
2. Use tf.summary to replace tf.image_summary, tf.scalar_summary, tf.merge_all_summaries.

* remove log
parent f88eef94
...@@ -84,7 +84,7 @@ def build_input(dataset, data_path, batch_size, mode): ...@@ -84,7 +84,7 @@ def build_input(dataset, data_path, batch_size, mode):
else: else:
image = tf.image.resize_image_with_crop_or_pad( image = tf.image.resize_image_with_crop_or_pad(
image, image_size, image_size) image, image_size, image_size)
image = tf.image.per_image_whitening(image) image = tf.image.per_image_standardization(image)
example_queue = tf.FIFOQueue( example_queue = tf.FIFOQueue(
3 * batch_size, 3 * batch_size,
...@@ -112,5 +112,5 @@ def build_input(dataset, data_path, batch_size, mode): ...@@ -112,5 +112,5 @@ def build_input(dataset, data_path, batch_size, mode):
assert labels.get_shape()[1] == num_classes assert labels.get_shape()[1] == num_classes
# Display the training images in the visualizer. # Display the training images in the visualizer.
tf.image_summary('images', images) tf.summary.image('images', images)
return images, labels return images, labels
...@@ -70,8 +70,8 @@ def train(hps): ...@@ -70,8 +70,8 @@ def train(hps):
summary_hook = tf.train.SummarySaverHook( summary_hook = tf.train.SummarySaverHook(
save_steps=100, save_steps=100,
output_dir=FLAGS.train_dir, output_dir=FLAGS.train_dir,
summary_op=[model.summaries, summary_op=tf.summary.merge([model.summaries,
tf.summary.scalar('Precision', precision)]) tf.summary.scalar('Precision', precision)]))
logging_hook = tf.train.LoggingTensorHook( logging_hook = tf.train.LoggingTensorHook(
tensors={'step': model.global_step, tensors={'step': model.global_step,
......
...@@ -59,7 +59,7 @@ class ResNet(object): ...@@ -59,7 +59,7 @@ class ResNet(object):
self._build_model() self._build_model()
if self.mode == 'train': if self.mode == 'train':
self._build_train_op() self._build_train_op()
self.summaries = tf.merge_all_summaries() self.summaries = tf.summary.merge_all()
def _stride_arr(self, stride): def _stride_arr(self, stride):
"""Map a stride scalar to the stride array for tf.nn.conv2d.""" """Map a stride scalar to the stride array for tf.nn.conv2d."""
...@@ -122,12 +122,12 @@ class ResNet(object): ...@@ -122,12 +122,12 @@ class ResNet(object):
self.cost = tf.reduce_mean(xent, name='xent') self.cost = tf.reduce_mean(xent, name='xent')
self.cost += self._decay() self.cost += self._decay()
tf.scalar_summary('cost', self.cost) tf.summary.scalar('cost', self.cost)
def _build_train_op(self): def _build_train_op(self):
"""Build training specific ops for the graph.""" """Build training specific ops for the graph."""
self.lrn_rate = tf.constant(self.hps.lrn_rate, tf.float32) self.lrn_rate = tf.constant(self.hps.lrn_rate, tf.float32)
tf.scalar_summary('learning rate', self.lrn_rate) tf.summary.scalar('learning rate', self.lrn_rate)
trainable_variables = tf.trainable_variables() trainable_variables = tf.trainable_variables()
grads = tf.gradients(self.cost, trainable_variables) grads = tf.gradients(self.cost, trainable_variables)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment