Commit 1453d070 authored by MTDzi

Fix to Memory.query + other small changes in Learning to Remember Rare Events

parent e029542a
@@ -173,6 +173,21 @@ class Memory(object):
     softmax_temp = max(1.0, np.log(0.2 * self.choose_k) / self.alpha)
     mask = tf.nn.softmax(hint_pool_sims[:, :choose_k - 1] * softmax_temp)
 
+    # prepare returned values
+    nearest_neighbor = tf.to_int32(
+        tf.argmax(hint_pool_sims[:, :choose_k - 1], 1))
+    no_teacher_idxs = tf.gather(
+        tf.reshape(hint_pool_idxs, [-1]),
+        nearest_neighbor + choose_k * tf.range(batch_size))
+    with tf.device(self.var_cache_device):
+      result = tf.gather(self.mem_vals, tf.reshape(no_teacher_idxs, [-1]))
+
+    # when no target output is given (pure inference), return the
+    # nearest-neighbor prediction without computing a teacher loss
+    if not output_given:
+      teacher_loss = None
+      return result, mask, teacher_loss
+
     # prepare hints from the teacher on hint pool
     teacher_hints = tf.to_float(
         tf.abs(tf.expand_dims(intended_output, 1) - hint_pool_mem_vals))
@@ -192,13 +207,6 @@ class Memory(object):
     teacher_vals *= (
         1 - tf.to_float(tf.equal(0.0, tf.reduce_sum(teacher_hints, 1))))
 
-    # prepare returned values
-    nearest_neighbor = tf.to_int32(
-        tf.argmax(hint_pool_sims[:, :choose_k - 1], 1))
-    no_teacher_idxs = tf.gather(
-        tf.reshape(hint_pool_idxs, [-1]),
-        nearest_neighbor + choose_k * tf.range(batch_size))
-
     # we'll determine whether to do an update to memory based on whether
     # memory was queried correctly
     sliced_hints = tf.slice(teacher_hints, [0, 0], [-1, self.correct_in_top])
@@ -208,9 +216,6 @@ class Memory(object):
     teacher_loss = (tf.nn.relu(neg_teacher_vals - teacher_vals + self.alpha)
                     - self.alpha)
 
-    with tf.device(self.var_cache_device):
-      result = tf.gather(self.mem_vals, tf.reshape(no_teacher_idxs, [-1]))
-
     # prepare memory updates
     update_keys = normalized_query
     update_vals = intended_output
...
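Note on the Memory.query change above: the nearest-neighbor lookup and the `result` gather used to sit below the teacher-hint computation, which needs `intended_output`; moving them up and returning early when no output is given lets the memory be queried for pure inference. Below is a minimal usage sketch, not code from the commit: the module path, constructor arguments, and shapes are illustrative assumptions, while `query()` and its `intended_output` argument are taken from the diff.

import tensorflow as tf
from memory import Memory  # assumed module path and class name

# Illustrative sizes; the real constructor signature may differ.
mem = Memory(key_dim=128, memory_size=8192, vocab_size=100)

query_vec = tf.placeholder(tf.float32, [16, 128])  # batch of query keys

# With no intended output, query() now takes the early-return path:
# it yields the nearest-neighbor values and a None teacher_loss instead
# of building teacher-loss ops that require ground-truth labels.
result, mask, teacher_loss = mem.query(query_vec, intended_output=None)
assert teacher_loss is None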
@@ -178,27 +178,13 @@ class Model(object):
     self.x, self.y = self.get_xy_placeholders()
 
+    # This context creates variables
     with tf.variable_scope('core', reuse=None):
       self.loss, self.gradient_ops = self.train(self.x, self.y)
 
+    # And this one re-uses them (thus the `reuse=True`)
     with tf.variable_scope('core', reuse=True):
       self.y_preds = self.eval(self.x, self.y)
 
-    # setup memory "reset" ops
-    (self.mem_keys, self.mem_vals,
-     self.mem_age, self.recent_idx) = self.memory.get()
-    self.mem_keys_reset = tf.placeholder(self.mem_keys.dtype,
-                                         tf.identity(self.mem_keys).shape)
-    self.mem_vals_reset = tf.placeholder(self.mem_vals.dtype,
-                                         tf.identity(self.mem_vals).shape)
-    self.mem_age_reset = tf.placeholder(self.mem_age.dtype,
-                                        tf.identity(self.mem_age).shape)
-    self.recent_idx_reset = tf.placeholder(self.recent_idx.dtype,
-                                           tf.identity(self.recent_idx).shape)
-    self.mem_reset_op = self.memory.set(self.mem_keys_reset,
-                                        self.mem_vals_reset,
-                                        self.mem_age_reset,
-                                        None)
-
   def training_ops(self, loss):
     opt = self.get_optimizer()
     params = tf.trainable_variables()
@@ -254,8 +240,14 @@ class Model(object):
       Predicted y.
     """
-    cur_memory = sess.run([self.mem_keys, self.mem_vals,
-                           self.mem_age])
+    # Storing current memory state to restore it after prediction
+    mem_keys, mem_vals, mem_age, _ = self.memory.get()
+    cur_memory = (
+        tf.identity(mem_keys),
+        tf.identity(mem_vals),
+        tf.identity(mem_age),
+        None,
+    )
 
     outputs = [self.y_preds]
 
     if y is None:
@@ -263,10 +255,8 @@ class Model(object):
     else:
       ret = sess.run(outputs, feed_dict={self.x: x, self.y: y})
 
-    sess.run([self.mem_reset_op],
-             feed_dict={self.mem_keys_reset: cur_memory[0],
-                        self.mem_vals_reset: cur_memory[1],
-                        self.mem_age_reset: cur_memory[2]})
+    # Restoring memory state
+    self.memory.set(*cur_memory)
 
     return ret
 
@@ -284,8 +274,14 @@ class Model(object):
       List of predicted y.
     """
-    cur_memory = sess.run([self.mem_keys, self.mem_vals,
-                           self.mem_age])
+    # Storing current memory state to restore it after prediction
+    mem_keys, mem_vals, mem_age, _ = self.memory.get()
+    cur_memory = (
+        tf.identity(mem_keys),
+        tf.identity(mem_vals),
+        tf.identity(mem_age),
+        None,
+    )
 
     if clear_memory:
       self.clear_memory(sess)
 
@@ -297,10 +293,8 @@ class Model(object):
       y_pred = out[0]
       y_preds.append(y_pred)
 
-    sess.run([self.mem_reset_op],
-             feed_dict={self.mem_keys_reset: cur_memory[0],
-                        self.mem_vals_reset: cur_memory[1],
-                        self.mem_age_reset: cur_memory[2]})
+    # Restoring memory state
+    self.memory.set(*cur_memory)
 
     return y_preds
 
...
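The placeholder-based `mem_reset_op` machinery is gone: `predict` and `episode_predict` now snapshot the memory via `Memory.get()` and hand the snapshot back to `Memory.set()`. The sketch below shows that round trip as a hypothetical helper; the helper name and the `fn` callback are illustrative, while `get()` and `set()` are used exactly as in the diff, with `None` for the `recent_idx` slot.

import tensorflow as tf

def with_memory_preserved(model, fn):
  """Hypothetical helper: run fn() and then restore the memory state."""
  mem_keys, mem_vals, mem_age, _ = model.memory.get()
  # Snapshot the memory tensors before fn() can mutate them.
  cur_memory = (
      tf.identity(mem_keys),
      tf.identity(mem_vals),
      tf.identity(mem_age),
      None,  # recent_idx is not restored, matching the diff
  )
  result = fn()
  model.memory.set(*cur_memory)  # hand the snapshot back
  return result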
@@ -208,17 +208,16 @@ class Trainer(object):
         correct.append(self.compute_correct(np.array(y), y_preds))
 
         # compute per-shot accuracies
-        seen_counts = [[0] * episode_width for _ in xrange(batch_size)]
+        seen_counts = [0] * episode_width
         # loop over episode steps
         for yy, yy_preds in zip(y, y_preds):
           # loop over batch examples
-          for k, (yyy, yyy_preds) in enumerate(zip(yy, yy_preds)):
-            yyy, yyy_preds = int(yyy), int(yyy_preds)
-            count = seen_counts[k][yyy % episode_width]
-            if count in correct_by_shot:
-              correct_by_shot[count].append(
-                  self.individual_compute_correct(yyy, yyy_preds))
-            seen_counts[k][yyy % episode_width] = count + 1
+          yyy, yyy_preds = int(yy[0]), int(yy_preds[0])
+          count = seen_counts[yyy % episode_width]
+          if count in correct_by_shot:
+            correct_by_shot[count].append(
+                self.individual_compute_correct(yyy, yyy_preds))
+          seen_counts[yyy % episode_width] = count + 1
 
         logging.info('validation overall accuracy %f', np.mean(correct))
         logging.info('%d-shot: %.3f, ' * num_shots,
...
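The validation change above drops the per-batch-element `seen_counts` table and tracks shots for the first batch example only: `seen_counts[label % episode_width]` counts how many times that class slot has already appeared, so its value at each step is the current shot index. A small worked example with made-up labels follows, using `int(label == pred)` as a stand-in for `individual_compute_correct`.

import numpy as np

episode_width = 2   # distinct classes per episode
num_shots = 3
correct_by_shot = {shot: [] for shot in range(num_shots)}
seen_counts = [0] * episode_width

# (label, prediction) pairs for the first batch example across steps.
episode = [(0, 1), (1, 1), (0, 0), (1, 1), (0, 0), (1, 0)]
for yyy, yyy_preds in episode:
  count = seen_counts[yyy % episode_width]  # current shot index for this class
  if count in correct_by_shot:
    correct_by_shot[count].append(int(yyy == yyy_preds))
  seen_counts[yyy % episode_width] = count + 1

# First sightings (0, 1) and (1, 1): one wrong, one right -> 1-shot 0.5.
# Second sightings (0, 0) and (1, 1): both right -> 2-shot 1.0.
print({shot + 1: np.mean(v) for shot, v in correct_by_shot.items()})
# {1: 0.5, 2: 1.0, 3: 0.5}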