Unverified commit 92083555, authored by Lukasz Kaiser, committed by GitHub

Merge pull request #3534 from MTDzi/master

Fix to Memory.query + other small changes in Learning to Remember Rare Events
parents 95385809 c173bd9b
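
Most of the smaller changes below are Python 2/3 compatibility fixes: cPickle is imported through six.moves, pickled files are opened in binary mode, dict iteration uses items()/values() instead of iteritems()/itervalues(), and integer division is written with //. A minimal standalone sketch of the same pattern (the file name example.pkl and the toy data are made up for illustration, not taken from the repo):

    # Python 2/3 compatible pickling sketch (hypothetical example.pkl, not from this repo).
    from six.moves import cPickle as pickle  # resolves to the plain pickle module on Python 3

    data = {'a': [1, 2], 'b': [3, 4]}
    with open('example.pkl', 'wb') as f:     # binary mode is required on Python 3
        pickle.dump(data, f)
    with open('example.pkl', 'rb') as f:
        restored = pickle.load(f)

    # dict.iteritems()/itervalues() are gone in Python 3; .items()/.values() work in both.
    sizes = [len(v) for _, v in restored.items()]
    # "/" is true division in Python 3; "//" keeps the Python 2 integer-division behaviour.
    per_bucket = sum(sizes) // len(sizes)
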
@@ -20,10 +20,10 @@ Simply call
   python data_utils.py
 """

-import cPickle as pickle
 import logging
 import os
 import subprocess

+from six.moves import cPickle as pickle
 import numpy as np
 from scipy.misc import imresize
@@ -54,9 +54,9 @@ def get_data():
     Train and test data as dictionaries mapping
     label to list of examples.
   """
-  with tf.gfile.GFile(DATA_FILE_FORMAT % 'train') as f:
+  with tf.gfile.GFile(DATA_FILE_FORMAT % 'train', 'rb') as f:
     processed_train_data = pickle.load(f)
-  with tf.gfile.GFile(DATA_FILE_FORMAT % 'test') as f:
+  with tf.gfile.GFile(DATA_FILE_FORMAT % 'test', 'rb') as f:
     processed_test_data = pickle.load(f)
   train_data = {}
@@ -72,9 +72,9 @@ def get_data():
   intersection = set(train_data.keys()) & set(test_data.keys())
   assert not intersection, 'Train and test data intersect.'
-  ok_num_examples = [len(ll) == 20 for _, ll in train_data.iteritems()]
+  ok_num_examples = [len(ll) == 20 for _, ll in train_data.items()]
   assert all(ok_num_examples), 'Bad number of examples in train data.'
-  ok_num_examples = [len(ll) == 20 for _, ll in test_data.iteritems()]
+  ok_num_examples = [len(ll) == 20 for _, ll in test_data.items()]
   assert all(ok_num_examples), 'Bad number of examples in test data.'
   logging.info('Number of labels in train data: %d.', len(train_data))
...
@@ -173,6 +173,21 @@ class Memory(object):
     softmax_temp = max(1.0, np.log(0.2 * self.choose_k) / self.alpha)
     mask = tf.nn.softmax(hint_pool_sims[:, :choose_k - 1] * softmax_temp)
+
+    # prepare returned values
+    nearest_neighbor = tf.to_int32(
+        tf.argmax(hint_pool_sims[:, :choose_k - 1], 1))
+    no_teacher_idxs = tf.gather(
+        tf.reshape(hint_pool_idxs, [-1]),
+        nearest_neighbor + choose_k * tf.range(batch_size))
+
+    with tf.device(self.var_cache_device):
+      result = tf.gather(self.mem_vals, tf.reshape(no_teacher_idxs, [-1]))
+
+    if not output_given:
+      teacher_loss = None
+      return result, mask, teacher_loss
+
     # prepare hints from the teacher on hint pool
     teacher_hints = tf.to_float(
         tf.abs(tf.expand_dims(intended_output, 1) - hint_pool_mem_vals))
@@ -192,13 +207,6 @@ class Memory(object):
     teacher_vals *= (
         1 - tf.to_float(tf.equal(0.0, tf.reduce_sum(teacher_hints, 1))))
-
-    # prepare returned values
-    nearest_neighbor = tf.to_int32(
-        tf.argmax(hint_pool_sims[:, :choose_k - 1], 1))
-    no_teacher_idxs = tf.gather(
-        tf.reshape(hint_pool_idxs, [-1]),
-        nearest_neighbor + choose_k * tf.range(batch_size))
     # we'll determine whether to do an update to memory based on whether
     # memory was queried correctly
     sliced_hints = tf.slice(teacher_hints, [0, 0], [-1, self.correct_in_top])
@@ -208,9 +216,6 @@ class Memory(object):
     teacher_loss = (tf.nn.relu(neg_teacher_vals - teacher_vals + self.alpha)
                     - self.alpha)
-
-    with tf.device(self.var_cache_device):
-      result = tf.gather(self.mem_vals, tf.reshape(no_teacher_idxs, [-1]))
     # prepare memory updates
     update_keys = normalized_query
     update_vals = intended_output
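
Taken together, the memory.py hunks above move the nearest-neighbour lookup and the value gather ahead of the teacher-loss computation, so Memory.query can return early with a None teacher loss when no intended output is supplied (the output_given branch in the diff). A rough numpy sketch of that control flow, not the module's actual API (query_memory and its toy loss stand-in are invented for illustration):

    import numpy as np

    def query_memory(query, mem_keys, mem_vals, intended_output=None):
        """Nearest-neighbour lookup; a teacher loss is computed only when a label is given."""
        sims = mem_keys @ query                  # similarities (keys assumed unit-norm)
        nearest = int(np.argmax(sims))
        result = int(mem_vals[nearest])
        if intended_output is None:              # inference path: no teacher signal
            return result, None
        teacher_loss = float(result != intended_output)  # stand-in for the real hinge loss
        return result, teacher_loss

    mem_keys = np.eye(3)                         # three unit-norm keys
    mem_vals = np.array([7, 8, 9])
    print(query_memory(np.array([0.0, 1.0, 0.0]), mem_keys, mem_vals))      # (8, None)
    print(query_memory(np.array([0.0, 1.0, 0.0]), mem_keys, mem_vals, 8))   # (8, 0.0)
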
...
@@ -178,27 +178,13 @@ class Model(object):
     self.x, self.y = self.get_xy_placeholders()

+    # This context creates variables
     with tf.variable_scope('core', reuse=None):
       self.loss, self.gradient_ops = self.train(self.x, self.y)

+    # And this one re-uses them (thus the `reuse=True`)
     with tf.variable_scope('core', reuse=True):
       self.y_preds = self.eval(self.x, self.y)

-    # setup memory "reset" ops
-    (self.mem_keys, self.mem_vals,
-     self.mem_age, self.recent_idx) = self.memory.get()
-    self.mem_keys_reset = tf.placeholder(self.mem_keys.dtype,
-                                         tf.identity(self.mem_keys).shape)
-    self.mem_vals_reset = tf.placeholder(self.mem_vals.dtype,
-                                         tf.identity(self.mem_vals).shape)
-    self.mem_age_reset = tf.placeholder(self.mem_age.dtype,
-                                        tf.identity(self.mem_age).shape)
-    self.recent_idx_reset = tf.placeholder(self.recent_idx.dtype,
-                                           tf.identity(self.recent_idx).shape)
-    self.mem_reset_op = self.memory.set(self.mem_keys_reset,
-                                        self.mem_vals_reset,
-                                        self.mem_age_reset,
-                                        None)

   def training_ops(self, loss):
     opt = self.get_optimizer()
     params = tf.trainable_variables()
@@ -254,8 +240,14 @@ class Model(object):
       Predicted y.
     """
-    cur_memory = sess.run([self.mem_keys, self.mem_vals,
-                           self.mem_age])
+    # Storing current memory state to restore it after prediction
+    mem_keys, mem_vals, mem_age, _ = self.memory.get()
+    cur_memory = (
+        tf.identity(mem_keys),
+        tf.identity(mem_vals),
+        tf.identity(mem_age),
+        None,
+    )

     outputs = [self.y_preds]
     if y is None:
@@ -263,10 +255,8 @@ class Model(object):
     else:
       ret = sess.run(outputs, feed_dict={self.x: x, self.y: y})

-    sess.run([self.mem_reset_op],
-             feed_dict={self.mem_keys_reset: cur_memory[0],
-                        self.mem_vals_reset: cur_memory[1],
-                        self.mem_age_reset: cur_memory[2]})
+    # Restoring memory state
+    self.memory.set(*cur_memory)

     return ret
@@ -284,8 +274,14 @@ class Model(object):
       List of predicted y.
     """
-    cur_memory = sess.run([self.mem_keys, self.mem_vals,
-                           self.mem_age])
+    # Storing current memory state to restore it after prediction
+    mem_keys, mem_vals, mem_age, _ = self.memory.get()
+    cur_memory = (
+        tf.identity(mem_keys),
+        tf.identity(mem_vals),
+        tf.identity(mem_age),
+        None,
+    )

     if clear_memory:
       self.clear_memory(sess)
@@ -297,10 +293,8 @@ class Model(object):
       y_pred = out[0]
       y_preds.append(y_pred)

-    sess.run([self.mem_reset_op],
-             feed_dict={self.mem_keys_reset: cur_memory[0],
-                        self.mem_vals_reset: cur_memory[1],
-                        self.mem_age_reset: cur_memory[2]})
+    # Restoring memory state
+    self.memory.set(*cur_memory)

     return y_preds
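
Both prediction paths in model.py now follow the same save/restore pattern: grab the memory state with memory.get(), run the episode (queries may write new entries into memory), then hand the saved state back to memory.set() so evaluation leaves the training-time memory untouched. A plain-Python sketch of that discipline (the Store class and predict_without_side_effects are invented for illustration and are not part of this codebase):

    class Store(object):
        """Toy stand-in for the episodic memory: get()/set() mirror the diff's usage."""

        def __init__(self):
            self.keys, self.vals, self.age = [], [], 0

        def get(self):
            return list(self.keys), list(self.vals), self.age, None

        def set(self, keys, vals, age, recent_idx=None):
            self.keys, self.vals, self.age = list(keys), list(vals), age

    def predict_without_side_effects(store, run_queries):
        snapshot = store.get()          # store current memory state
        try:
            return run_queries(store)   # queries may insert new (key, value) pairs
        finally:
            store.set(*snapshot)        # restore memory state afterwards

    store = Store()
    print(predict_without_side_effects(store, lambda s: s.keys.append('q') or 'prediction'))
    print(store.keys)  # [] -- the key appended during prediction was rolled back
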
...
@@ -112,7 +112,7 @@ class Trainer(object):
       remainders = [0] * (episode_width - remainder) + [1] * remainder
       episode_x = [
           random.sample(data[lab],
-                        r + (episode_length - remainder) / episode_width)
+                        r + (episode_length - remainder) // episode_width)
           for lab, r in zip(episode_labels, remainders)]
       episode = sum([[(x, i, ii) for ii, x in enumerate(xx)]
                      for i, xx in enumerate(episode_x)], [])
@@ -160,9 +160,9 @@ class Trainer(object):
     logging.info('batch_size %d', batch_size)

     assert all(len(v) >= float(episode_length) / episode_width
-               for v in train_data.itervalues())
+               for v in train_data.values())
     assert all(len(v) >= float(episode_length) / episode_width
-               for v in valid_data.itervalues())
+               for v in valid_data.values())

     output_dim = episode_width
     self.model = self.get_model()
@@ -208,17 +208,16 @@ class Trainer(object):
           correct.append(self.compute_correct(np.array(y), y_preds))

           # compute per-shot accuracies
-          seen_counts = [[0] * episode_width for _ in xrange(batch_size)]
+          seen_counts = [0] * episode_width
           # loop over episode steps
           for yy, yy_preds in zip(y, y_preds):
             # loop over batch examples
-            for k, (yyy, yyy_preds) in enumerate(zip(yy, yy_preds)):
-              yyy, yyy_preds = int(yyy), int(yyy_preds)
-              count = seen_counts[k][yyy % episode_width]
-              if count in correct_by_shot:
-                correct_by_shot[count].append(
-                    self.individual_compute_correct(yyy, yyy_preds))
-              seen_counts[k][yyy % episode_width] = count + 1
+            yyy, yyy_preds = int(yy[0]), int(yy_preds[0])
+            count = seen_counts[yyy % episode_width]
+            if count in correct_by_shot:
+              correct_by_shot[count].append(
+                  self.individual_compute_correct(yyy, yyy_preds))
+            seen_counts[yyy % episode_width] = count + 1

           logging.info('validation overall accuracy %f', np.mean(correct))
           logging.info('%d-shot: %.3f, ' * num_shots,
...
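
The last train.py hunk simplifies the per-shot bookkeeping to one counter per label; indexing yy[0] and yy_preds[0] means it assumes a single episode per validation batch. The counting rule itself (the k-th time a label appears in an episode contributes to k-shot accuracy) can be sketched standalone with made-up labels and predictions:

    # Per-shot accuracy: the k-th presentation of a label counts toward "k-shot" accuracy.
    labels = [3, 5, 3, 5, 3]           # made-up episode labels (batch of size 1)
    predictions = [1, 5, 3, 5, 3]      # made-up model outputs
    episode_width = 8
    num_shots = 3

    seen_counts = [0] * episode_width
    correct_by_shot = {k: [] for k in range(num_shots)}
    for y, y_pred in zip(labels, predictions):
        count = seen_counts[y % episode_width]
        if count in correct_by_shot:
            correct_by_shot[count].append(int(y == y_pred))
        seen_counts[y % episode_width] = count + 1

    # 0-shot: first presentations of 3 and 5 -> [0, 1]; 1-shot -> [1, 1]; 2-shot -> [1]
    print({k: sum(v) / float(len(v)) for k, v in correct_by_shot.items() if v})
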