Commit c173bd9b authored by MTDzi's avatar MTDzi
Browse files

Changes to "Learning to Remember..." to make it runnable in Python3.5

parent 1453d070
...@@ -20,10 +20,10 @@ Simply call ...@@ -20,10 +20,10 @@ Simply call
python data_utils.py python data_utils.py
""" """
import cPickle as pickle
import logging import logging
import os import os
import subprocess import subprocess
from six.moves import cPickle as pickle
import numpy as np import numpy as np
from scipy.misc import imresize from scipy.misc import imresize
...@@ -54,9 +54,9 @@ def get_data(): ...@@ -54,9 +54,9 @@ def get_data():
Train and test data as dictionaries mapping Train and test data as dictionaries mapping
label to list of examples. label to list of examples.
""" """
with tf.gfile.GFile(DATA_FILE_FORMAT % 'train') as f: with tf.gfile.GFile(DATA_FILE_FORMAT % 'train', 'rb') as f:
processed_train_data = pickle.load(f) processed_train_data = pickle.load(f)
with tf.gfile.GFile(DATA_FILE_FORMAT % 'test') as f: with tf.gfile.GFile(DATA_FILE_FORMAT % 'test', 'rb') as f:
processed_test_data = pickle.load(f) processed_test_data = pickle.load(f)
train_data = {} train_data = {}
...@@ -72,9 +72,9 @@ def get_data(): ...@@ -72,9 +72,9 @@ def get_data():
intersection = set(train_data.keys()) & set(test_data.keys()) intersection = set(train_data.keys()) & set(test_data.keys())
assert not intersection, 'Train and test data intersect.' assert not intersection, 'Train and test data intersect.'
ok_num_examples = [len(ll) == 20 for _, ll in train_data.iteritems()] ok_num_examples = [len(ll) == 20 for _, ll in train_data.items()]
assert all(ok_num_examples), 'Bad number of examples in train data.' assert all(ok_num_examples), 'Bad number of examples in train data.'
ok_num_examples = [len(ll) == 20 for _, ll in test_data.iteritems()] ok_num_examples = [len(ll) == 20 for _, ll in test_data.items()]
assert all(ok_num_examples), 'Bad number of examples in test data.' assert all(ok_num_examples), 'Bad number of examples in test data.'
logging.info('Number of labels in train data: %d.', len(train_data)) logging.info('Number of labels in train data: %d.', len(train_data))
......
...@@ -112,7 +112,7 @@ class Trainer(object): ...@@ -112,7 +112,7 @@ class Trainer(object):
remainders = [0] * (episode_width - remainder) + [1] * remainder remainders = [0] * (episode_width - remainder) + [1] * remainder
episode_x = [ episode_x = [
random.sample(data[lab], random.sample(data[lab],
r + (episode_length - remainder) / episode_width) r + (episode_length - remainder) // episode_width)
for lab, r in zip(episode_labels, remainders)] for lab, r in zip(episode_labels, remainders)]
episode = sum([[(x, i, ii) for ii, x in enumerate(xx)] episode = sum([[(x, i, ii) for ii, x in enumerate(xx)]
for i, xx in enumerate(episode_x)], []) for i, xx in enumerate(episode_x)], [])
...@@ -160,9 +160,9 @@ class Trainer(object): ...@@ -160,9 +160,9 @@ class Trainer(object):
logging.info('batch_size %d', batch_size) logging.info('batch_size %d', batch_size)
assert all(len(v) >= float(episode_length) / episode_width assert all(len(v) >= float(episode_length) / episode_width
for v in train_data.itervalues()) for v in train_data.values())
assert all(len(v) >= float(episode_length) / episode_width assert all(len(v) >= float(episode_length) / episode_width
for v in valid_data.itervalues()) for v in valid_data.values())
output_dim = episode_width output_dim = episode_width
self.model = self.get_model() self.model = self.get_model()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment