Commit 1453d070 authored by MTDzi

Fix to Memory.query + other small changes in Learning to Remember Rare Events

parent e029542a
......@@ -173,6 +173,21 @@ class Memory(object):
softmax_temp = max(1.0, np.log(0.2 * self.choose_k) / self.alpha)
mask = tf.nn.softmax(hint_pool_sims[:, :choose_k - 1] * softmax_temp)
# prepare returned values
nearest_neighbor = tf.to_int32(
tf.argmax(hint_pool_sims[:, :choose_k - 1], 1))
no_teacher_idxs = tf.gather(
tf.reshape(hint_pool_idxs, [-1]),
nearest_neighbor + choose_k * tf.range(batch_size))
with tf.device(self.var_cache_device):
result = tf.gather(self.mem_vals, tf.reshape(no_teacher_idxs, [-1]))
if not output_given:
teacher_loss = None
return result, mask, teacher_loss
# prepare hints from the teacher on hint pool
teacher_hints = tf.to_float(
tf.abs(tf.expand_dims(intended_output, 1) - hint_pool_mem_vals))
......@@ -192,13 +207,6 @@ class Memory(object):
teacher_vals *= (
1 - tf.to_float(tf.equal(0.0, tf.reduce_sum(teacher_hints, 1))))
# prepare returned values
nearest_neighbor = tf.to_int32(
tf.argmax(hint_pool_sims[:, :choose_k - 1], 1))
no_teacher_idxs = tf.gather(
tf.reshape(hint_pool_idxs, [-1]),
nearest_neighbor + choose_k * tf.range(batch_size))
# we'll determine whether to do an update to memory based on whether
# memory was queried correctly
sliced_hints = tf.slice(teacher_hints, [0, 0], [-1, self.correct_in_top])
......@@ -208,9 +216,6 @@ class Memory(object):
teacher_loss = (tf.nn.relu(neg_teacher_vals - teacher_vals + self.alpha)
- self.alpha)
with tf.device(self.var_cache_device):
result = tf.gather(self.mem_vals, tf.reshape(no_teacher_idxs, [-1]))
# prepare memory updates
update_keys = normalized_query
update_vals = intended_output
......
......@@ -178,27 +178,13 @@ class Model(object):
self.x, self.y = self.get_xy_placeholders()
# This context creates variables
with tf.variable_scope('core', reuse=None):
self.loss, self.gradient_ops = self.train(self.x, self.y)
# And this one re-uses them (thus the `reuse=True`)
with tf.variable_scope('core', reuse=True):
self.y_preds = self.eval(self.x, self.y)
# setup memory "reset" ops
(self.mem_keys, self.mem_vals,
self.mem_age, self.recent_idx) = self.memory.get()
self.mem_keys_reset = tf.placeholder(self.mem_keys.dtype,
tf.identity(self.mem_keys).shape)
self.mem_vals_reset = tf.placeholder(self.mem_vals.dtype,
tf.identity(self.mem_vals).shape)
self.mem_age_reset = tf.placeholder(self.mem_age.dtype,
tf.identity(self.mem_age).shape)
self.recent_idx_reset = tf.placeholder(self.recent_idx.dtype,
tf.identity(self.recent_idx).shape)
self.mem_reset_op = self.memory.set(self.mem_keys_reset,
self.mem_vals_reset,
self.mem_age_reset,
None)
def training_ops(self, loss):
opt = self.get_optimizer()
params = tf.trainable_variables()
......@@ -254,8 +240,14 @@ class Model(object):
Predicted y.
"""
cur_memory = sess.run([self.mem_keys, self.mem_vals,
self.mem_age])
# Storing current memory state to restore it after prediction
mem_keys, mem_vals, mem_age, _ = self.memory.get()
cur_memory = (
tf.identity(mem_keys),
tf.identity(mem_vals),
tf.identity(mem_age),
None,
)
outputs = [self.y_preds]
if y is None:
......@@ -263,10 +255,8 @@ class Model(object):
else:
ret = sess.run(outputs, feed_dict={self.x: x, self.y: y})
sess.run([self.mem_reset_op],
feed_dict={self.mem_keys_reset: cur_memory[0],
self.mem_vals_reset: cur_memory[1],
self.mem_age_reset: cur_memory[2]})
# Restoring memory state
self.memory.set(*cur_memory)
return ret
......@@ -284,8 +274,14 @@ class Model(object):
List of predicted y.
"""
cur_memory = sess.run([self.mem_keys, self.mem_vals,
self.mem_age])
# Storing current memory state to restore it after prediction
mem_keys, mem_vals, mem_age, _ = self.memory.get()
cur_memory = (
tf.identity(mem_keys),
tf.identity(mem_vals),
tf.identity(mem_age),
None,
)
if clear_memory:
self.clear_memory(sess)
......@@ -297,10 +293,8 @@ class Model(object):
y_pred = out[0]
y_preds.append(y_pred)
sess.run([self.mem_reset_op],
feed_dict={self.mem_keys_reset: cur_memory[0],
self.mem_vals_reset: cur_memory[1],
self.mem_age_reset: cur_memory[2]})
# Restoring memory state
self.memory.set(*cur_memory)
return y_preds
......
......@@ -208,17 +208,16 @@ class Trainer(object):
correct.append(self.compute_correct(np.array(y), y_preds))
# compute per-shot accuracies
seen_counts = [[0] * episode_width for _ in xrange(batch_size)]
seen_counts = [0] * episode_width
# loop over episode steps
for yy, yy_preds in zip(y, y_preds):
# loop over batch examples
for k, (yyy, yyy_preds) in enumerate(zip(yy, yy_preds)):
yyy, yyy_preds = int(yyy), int(yyy_preds)
count = seen_counts[k][yyy % episode_width]
if count in correct_by_shot:
correct_by_shot[count].append(
self.individual_compute_correct(yyy, yyy_preds))
seen_counts[k][yyy % episode_width] = count + 1
yyy, yyy_preds = int(yy[0]), int(yy_preds[0])
count = seen_counts[yyy % episode_width]
if count in correct_by_shot:
correct_by_shot[count].append(
self.individual_compute_correct(yyy, yyy_preds))
seen_counts[yyy % episode_width] = count + 1
logging.info('validation overall accuracy %f', np.mean(correct))
logging.info('%d-shot: %.3f, ' * num_shots,
......
Markdown is supported
Attach a file by drag & drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.