Commit aae0a947 authored by Lukasz Kaiser's avatar Lukasz Kaiser Committed by GitHub
Browse files

Merge pull request #2116 from cclauss/patch-3

rebar: center() is defined in utils
parents 55b440f3 1d5dba69
...@@ -26,6 +26,11 @@ import tensorflow.contrib.slim as slim ...@@ -26,6 +26,11 @@ import tensorflow.contrib.slim as slim
from tensorflow.python.ops import init_ops from tensorflow.python.ops import init_ops
import utils as U import utils as U
try:
xrange # Python 2
except NameError:
xrange = range # Python 3
FLAGS = tf.flags.FLAGS FLAGS = tf.flags.FLAGS
Q_COLLECTION = "q_collection" Q_COLLECTION = "q_collection"
...@@ -293,7 +298,7 @@ class SBN(object): # REINFORCE ...@@ -293,7 +298,7 @@ class SBN(object): # REINFORCE
logQHard = tf.add_n(logQHard) logQHard = tf.add_n(logQHard)
# REINFORCE # REINFORCE
learning_signal = tf.stop_gradient(center(reinforce_learning_signal)) learning_signal = tf.stop_gradient(U.center(reinforce_learning_signal))
self.optimizerLoss = -(learning_signal*logQHard + self.optimizerLoss = -(learning_signal*logQHard +
reinforce_model_grad) reinforce_model_grad)
self.lHat = map(tf.reduce_mean, [ self.lHat = map(tf.reduce_mean, [
......
...@@ -28,6 +28,12 @@ import tensorflow as tf ...@@ -28,6 +28,12 @@ import tensorflow as tf
import rebar import rebar
import datasets import datasets
import logger as L import logger as L
try:
xrange # Python 2
except NameError:
xrange = range # Python 3
gfile = tf.gfile gfile = tf.gfile
tf.app.flags.DEFINE_string("working_dir", "/tmp/rebar", tf.app.flags.DEFINE_string("working_dir", "/tmp/rebar",
......
...@@ -134,6 +134,3 @@ def logSumExp(t, axis=0, keep_dims = False): ...@@ -134,6 +134,3 @@ def logSumExp(t, axis=0, keep_dims = False):
return tf.expand_dims(res, axis) return tf.expand_dims(res, axis)
else: else:
return res return res
if __name__ == '__main__':
app.run()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment