Commit b64919aa authored by MTDzi's avatar MTDzi
Browse files

Corrected number of shots in learning to remember rare events

parent 4d1f67cf
......@@ -197,7 +197,8 @@ class Trainer(object):
# validation
correct = []
correct_by_shot = dict((k, []) for k in xrange(self.episode_width + 1))
num_shots = episode_length // episode_width
correct_by_shot = dict((k, []) for k in xrange(num_shots))
for _ in xrange(FLAGS.validation_length):
x, y = self.sample_episode_batch(
valid_data, episode_length, episode_width, 1)
......@@ -213,16 +214,16 @@ class Trainer(object):
# loop over batch examples
for k, (yyy, yyy_preds) in enumerate(zip(yy, yy_preds)):
yyy, yyy_preds = int(yyy), int(yyy_preds)
count = seen_counts[k][yyy % self.episode_width]
count = seen_counts[k][yyy % episode_width]
if count in correct_by_shot:
correct_by_shot[count].append(
self.individual_compute_correct(yyy, yyy_preds))
seen_counts[k][yyy % self.episode_width] = count + 1
seen_counts[k][yyy % episode_width] = count + 1
logging.info('validation overall accuracy %f', np.mean(correct))
logging.info('%d-shot: %.3f, ' * (self.episode_width + 1),
logging.info('%d-shot: %.3f, ' * num_shots,
*sum([[k, np.mean(correct_by_shot[k])]
for k in xrange(self.episode_width + 1)], []))
for k in xrange(num_shots)], []))
if saver and FLAGS.save_dir:
saved_file = saver.save(sess,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment