Commit c173bd9b authored by MTDzi's avatar MTDzi
Browse files

Changes to "Learning to Remember..." to make it runnable in Python3.5

parent 1453d070
......@@ -20,10 +20,10 @@ Simply call
python data_utils.py
"""
import cPickle as pickle
import logging
import os
import subprocess
from six.moves import cPickle as pickle
import numpy as np
from scipy.misc import imresize
......@@ -54,9 +54,9 @@ def get_data():
Train and test data as dictionaries mapping
label to list of examples.
"""
with tf.gfile.GFile(DATA_FILE_FORMAT % 'train') as f:
with tf.gfile.GFile(DATA_FILE_FORMAT % 'train', 'rb') as f:
processed_train_data = pickle.load(f)
with tf.gfile.GFile(DATA_FILE_FORMAT % 'test') as f:
with tf.gfile.GFile(DATA_FILE_FORMAT % 'test', 'rb') as f:
processed_test_data = pickle.load(f)
train_data = {}
......@@ -72,9 +72,9 @@ def get_data():
intersection = set(train_data.keys()) & set(test_data.keys())
assert not intersection, 'Train and test data intersect.'
ok_num_examples = [len(ll) == 20 for _, ll in train_data.iteritems()]
ok_num_examples = [len(ll) == 20 for _, ll in train_data.items()]
assert all(ok_num_examples), 'Bad number of examples in train data.'
ok_num_examples = [len(ll) == 20 for _, ll in test_data.iteritems()]
ok_num_examples = [len(ll) == 20 for _, ll in test_data.items()]
assert all(ok_num_examples), 'Bad number of examples in test data.'
logging.info('Number of labels in train data: %d.', len(train_data))
......
......@@ -112,7 +112,7 @@ class Trainer(object):
remainders = [0] * (episode_width - remainder) + [1] * remainder
episode_x = [
random.sample(data[lab],
r + (episode_length - remainder) / episode_width)
r + (episode_length - remainder) // episode_width)
for lab, r in zip(episode_labels, remainders)]
episode = sum([[(x, i, ii) for ii, x in enumerate(xx)]
for i, xx in enumerate(episode_x)], [])
......@@ -160,9 +160,9 @@ class Trainer(object):
logging.info('batch_size %d', batch_size)
assert all(len(v) >= float(episode_length) / episode_width
for v in train_data.itervalues())
for v in train_data.values())
assert all(len(v) >= float(episode_length) / episode_width
for v in valid_data.itervalues())
for v in valid_data.values())
output_dim = episode_width
self.model = self.get_model()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment