Commit 6a2b8dd2 authored by cclauss's avatar cclauss Committed by GitHub
Browse files

center() is defined in utils

Also define xrange() for Python 3.  See #2105
parent 424f8da6
......@@ -26,6 +26,11 @@ import tensorflow.contrib.slim as slim
from tensorflow.python.ops import init_ops
import utils as U
try:
xrange # Python 2
except NameError:
xrange = range # Python 3
FLAGS = tf.flags.FLAGS
Q_COLLECTION = "q_collection"
......@@ -293,7 +298,7 @@ class SBN(object): # REINFORCE
logQHard = tf.add_n(logQHard)
# REINFORCE
learning_signal = tf.stop_gradient(center(reinforce_learning_signal))
learning_signal = tf.stop_gradient(U.center(reinforce_learning_signal))
self.optimizerLoss = -(learning_signal*logQHard +
reinforce_model_grad)
self.lHat = map(tf.reduce_mean, [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment