Commit aae0a947 authored by Lukasz Kaiser's avatar Lukasz Kaiser Committed by GitHub
Browse files

Merge pull request #2116 from cclauss/patch-3

rebar: center() is defined in utils
parents 55b440f3 1d5dba69
......@@ -26,6 +26,11 @@ import tensorflow.contrib.slim as slim
from tensorflow.python.ops import init_ops
import utils as U
try:
xrange # Python 2
except NameError:
xrange = range # Python 3
FLAGS = tf.flags.FLAGS
Q_COLLECTION = "q_collection"
......@@ -293,7 +298,7 @@ class SBN(object): # REINFORCE
logQHard = tf.add_n(logQHard)
# REINFORCE
learning_signal = tf.stop_gradient(center(reinforce_learning_signal))
learning_signal = tf.stop_gradient(U.center(reinforce_learning_signal))
self.optimizerLoss = -(learning_signal*logQHard +
reinforce_model_grad)
self.lHat = map(tf.reduce_mean, [
......
......@@ -28,6 +28,12 @@ import tensorflow as tf
import rebar
import datasets
import logger as L
try:
xrange # Python 2
except NameError:
xrange = range # Python 3
gfile = tf.gfile
tf.app.flags.DEFINE_string("working_dir", "/tmp/rebar",
......
......@@ -134,6 +134,3 @@ def logSumExp(t, axis=0, keep_dims = False):
return tf.expand_dims(res, axis)
else:
return res
if __name__ == '__main__':
app.run()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment