Commit cf6b9cee authored by Lukasz Kaiser, committed by GitHub

Merge pull request #1097 from ofirnachum/master

add learning to remember rare events
parents bc70271a 6a9c0da9
Code for the Memory Module as described
in "Learning to Remember Rare Events" by
Lukasz Kaiser, Ofir Nachum, Aurko Roy, and Samy Bengio,
published as a conference paper at ICLR 2017.
Requirements:
* TensorFlow (see tensorflow.org for how to install)
* Some basic command-line utilities (git, unzip).
Description:
The general memory module is located in memory.py.
Some code is provided to see the memory module in
action on the standard Omniglot dataset.
Download and setup the dataset using data_utils.py
and then run the training script train.py
(see example commands below).
Note that the structure and parameters of the model
are optimized for the data preparation as provided.
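For reference, the snippet below is a minimal sketch (not one of the provided
scripts) showing how the memory module in memory.py can be wired up on its
own, assuming the TensorFlow 1.x API used throughout this code; the
placeholder here stands in for the CNN embeddings produced by model.py.
```
import tensorflow as tf
import memory

# vocab_size mirrors episode_width * batch_size from the example command.
key_dim, memory_size, vocab_size = 128, 8192, 80

# Build the memory module; its variables are created at construction time.
mem = memory.Memory(key_dim, memory_size, vocab_size)

# In model.py the query vectors are CNN embeddings of the input images.
query_vec = tf.placeholder(tf.float32, [None, key_dim])
labels = tf.placeholder(tf.int32, [None])

# result: labels fetched from memory, mask: affinity of the query to the
# fetched entries, teacher_loss: loss used to train the embedding network.
result, mask, teacher_loss = mem.query(query_vec, labels)
```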
Quick Start:
First download and set up the Omniglot data by running:
```
python data_utils.py
```
Then run the training script:
```
python train.py --memory_size=8192 \
--batch_size=16 --validation_length=50 \
--episode_width=5 --episode_length=30
```
The first validation batch may look like this (although it is noisy):
```
0-shot: 0.040, 1-shot: 0.404, 2-shot: 0.516, 3-shot: 0.604,
4-shot: 0.656, 5-shot: 0.684
```
At step 500 you may see something like this:
```
0-shot: 0.036, 1-shot: 0.836, 2-shot: 0.900, 3-shot: 0.940,
4-shot: 0.944, 5-shot: 0.916
```
At step 4000 you may see something like this:
```
0-shot: 0.044, 1-shot: 0.960, 2-shot: 1.000, 3-shot: 0.988,
4-shot: 0.972, 5-shot: 0.992
```
Maintained by Ofir Nachum (ofirnachum) and
Lukasz Kaiser (lukaszkaiser).
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
"""Data loading and other utilities.
Use this file to first copy over and pre-process the Omniglot dataset.
Simply call
python data_utils.py
"""
import cPickle as pickle
import logging
import os
import subprocess
import numpy as np
from scipy.misc import imresize
from scipy.misc import imrotate
from scipy.ndimage import imread
import tensorflow as tf
MAIN_DIR = ''
REPO_LOCATION = 'https://github.com/brendenlake/omniglot.git'
REPO_DIR = os.path.join(MAIN_DIR, 'omniglot')
DATA_DIR = os.path.join(REPO_DIR, 'python')
TRAIN_DIR = os.path.join(DATA_DIR, 'images_background')
TEST_DIR = os.path.join(DATA_DIR, 'images_evaluation')
DATA_FILE_FORMAT = os.path.join(MAIN_DIR, '%s_omni.pkl')
TRAIN_ROTATIONS = True # augment training data with rotations
TEST_ROTATIONS = False # augment testing data with rotations
IMAGE_ORIGINAL_SIZE = 105
IMAGE_NEW_SIZE = 28
def get_data():
"""Get data in form suitable for episodic training.
Returns:
Train and test data as dictionaries mapping
label to list of examples.
"""
with tf.gfile.GFile(DATA_FILE_FORMAT % 'train') as f:
processed_train_data = pickle.load(f)
with tf.gfile.GFile(DATA_FILE_FORMAT % 'test') as f:
processed_test_data = pickle.load(f)
train_data = {}
test_data = {}
for data, processed_data in zip([train_data, test_data],
[processed_train_data, processed_test_data]):
for image, label in zip(processed_data['images'],
processed_data['labels']):
if label not in data:
data[label] = []
data[label].append(image.reshape([-1]).astype('float32'))
intersection = set(train_data.keys()) & set(test_data.keys())
assert not intersection, 'Train and test data intersect.'
ok_num_examples = [len(ll) == 20 for _, ll in train_data.iteritems()]
assert all(ok_num_examples), 'Bad number of examples in train data.'
ok_num_examples = [len(ll) == 20 for _, ll in test_data.iteritems()]
assert all(ok_num_examples), 'Bad number of examples in test data.'
logging.info('Number of labels in train data: %d.', len(train_data))
logging.info('Number of labels in test data: %d.', len(test_data))
return train_data, test_data
def crawl_directory(directory, augment_with_rotations=False,
first_label=0):
"""Crawls data directory and returns stuff."""
label_idx = first_label
images = []
labels = []
info = []
# traverse root directory
for root, _, files in os.walk(directory):
logging.info('Reading files from %s', root)
fileflag = 0
for file_name in files:
full_file_name = os.path.join(root, file_name)
img = imread(full_file_name, flatten=True)
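# With rotation augmentation, each rotated copy of a character is stored
# under its own label (label_idx + i); without augmentation only the
# unrotated image is kept and the loop breaks after the first angle.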
for i, angle in enumerate([0, 90, 180, 270]):
if not augment_with_rotations and i > 0:
break
images.append(imrotate(img, angle))
labels.append(label_idx + i)
info.append(full_file_name)
fileflag = 1
if fileflag:
label_idx += 4 if augment_with_rotations else 1
return images, labels, info
def resize_images(images, new_width, new_height):
"""Resize images to new dimensions."""
resized_images = np.zeros([images.shape[0], new_width, new_height],
dtype=np.float32)
for i in range(images.shape[0]):
resized_images[i, :, :] = imresize(images[i, :, :],
[new_width, new_height],
interp='bilinear',
mode=None)
return resized_images
def write_datafiles(directory, write_file,
resize=True, rotate=False,
new_width=IMAGE_NEW_SIZE, new_height=IMAGE_NEW_SIZE,
first_label=0):
"""Load and preprocess images from a directory and write them to a file.
Args:
directory: Directory of alphabet sub-directories.
write_file: Filename to write to.
resize: Whether to resize the images.
rotate: Whether to augment the dataset with rotations.
new_width: New resize width.
new_height: New resize height.
first_label: Label to start with.
Returns:
Number of new labels created.
"""
# these are the default sizes for Omniglot:
imgwidth = IMAGE_ORIGINAL_SIZE
imgheight = IMAGE_ORIGINAL_SIZE
logging.info('Reading the data.')
images, labels, info = crawl_directory(directory,
augment_with_rotations=rotate,
first_label=first_label)
images_np = np.zeros([len(images), imgwidth, imgheight], dtype=np.bool)
labels_np = np.zeros([len(labels)], dtype=np.uint32)
for i in xrange(len(images)):
images_np[i, :, :] = images[i]
labels_np[i] = labels[i]
if resize:
logging.info('Resizing images.')
resized_images = resize_images(images_np, new_width, new_height)
logging.info('Writing resized data in float32 format.')
data = {'images': resized_images,
'labels': labels_np,
'info': info}
with tf.gfile.GFile(write_file, 'w') as f:
pickle.dump(data, f)
else:
logging.info('Writing original sized data in boolean format.')
data = {'images': images_np,
'labels': labels_np,
'info': info}
with tf.gfile.GFile(write_file, 'w') as f:
pickle.dump(data, f)
return len(np.unique(labels_np))
def maybe_download_data():
"""Download Omniglot repo if it does not exist."""
if os.path.exists(REPO_DIR):
logging.info('It appears that Git repo already exists.')
else:
logging.info('It appears that Git repo does not exist.')
logging.info('Cloning now.')
subprocess.check_output('git clone %s' % REPO_LOCATION, shell=True)
if os.path.exists(TRAIN_DIR):
logging.info('It appears that train data has already been unzipped.')
else:
logging.info('It appears that train data has not been unzipped.')
logging.info('Unzipping now.')
subprocess.check_output('unzip %s.zip -d %s' % (TRAIN_DIR, DATA_DIR),
shell=True)
if os.path.exists(TEST_DIR):
logging.info('It appears that test data has already been unzipped.')
else:
logging.info('It appears that test data has not been unzipped.')
logging.info('Unzipping now.')
subprocess.check_output('unzip %s.zip -d %s' % (TEST_DIR, DATA_DIR),
shell=True)
def preprocess_omniglot():
"""Download and prepare raw Omniglot data.
Downloads the data from GitHub if it does not exist.
Then load the images, augment with rotations if desired.
Resize the images and write them to a pickle file.
"""
maybe_download_data()
directory = TRAIN_DIR
write_file = DATA_FILE_FORMAT % 'train'
num_labels = write_datafiles(
directory, write_file, resize=True, rotate=TRAIN_ROTATIONS,
new_width=IMAGE_NEW_SIZE, new_height=IMAGE_NEW_SIZE)
directory = TEST_DIR
write_file = DATA_FILE_FORMAT % 'test'
write_datafiles(directory, write_file, resize=True, rotate=TEST_ROTATIONS,
new_width=IMAGE_NEW_SIZE, new_height=IMAGE_NEW_SIZE,
first_label=num_labels)
def main(unused_argv):
logging.basicConfig(level=logging.INFO)
preprocess_omniglot()
if __name__ == '__main__':
tf.app.run()
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
"""Memory module for storing "nearest neighbors".
Implements a key-value memory for generalized one-shot learning
as described in the paper
"Learning to Remember Rare Events"
by Lukasz Kaiser, Ofir Nachum, Aurko Roy, Samy Bengio,
published as a conference paper at ICLR 2017.
"""
import numpy as np
import tensorflow as tf
class Memory(object):
"""Memory module."""
def __init__(self, key_dim, memory_size, vocab_size,
choose_k=256, alpha=0.1, correct_in_top=1, age_noise=8.0,
var_cache_device='', nn_device=''):
self.key_dim = key_dim
self.memory_size = memory_size
self.vocab_size = vocab_size
self.choose_k = min(choose_k, memory_size)
self.alpha = alpha
self.correct_in_top = correct_in_top
self.age_noise = age_noise
self.var_cache_device = var_cache_device # Variables are cached here.
self.nn_device = nn_device # Device to perform nearest neighbour matmul.
caching_device = var_cache_device if var_cache_device else None
self.update_memory = tf.constant(True) # Can be fed "false" if needed.
self.mem_keys = tf.get_variable(
'memkeys', [self.memory_size, self.key_dim], trainable=False,
initializer=tf.random_uniform_initializer(-0.0, 0.0),
caching_device=caching_device)
self.mem_vals = tf.get_variable(
'memvals', [self.memory_size], dtype=tf.int32, trainable=False,
initializer=tf.constant_initializer(0, tf.int32),
caching_device=caching_device)
self.mem_age = tf.get_variable(
'memage', [self.memory_size], dtype=tf.float32, trainable=False,
initializer=tf.constant_initializer(0.0), caching_device=caching_device)
self.recent_idx = tf.get_variable(
'recent_idx', [self.vocab_size], dtype=tf.int32, trainable=False,
initializer=tf.constant_initializer(0, tf.int32))
# variable for projecting query vector into memory key
self.query_proj = tf.get_variable(
'memory_query_proj', [self.key_dim, self.key_dim], dtype=tf.float32,
initializer=tf.truncated_normal_initializer(0, 0.01),
caching_device=caching_device)
def get(self):
return self.mem_keys, self.mem_vals, self.mem_age, self.recent_idx
def set(self, k, v, a, r=None):
return tf.group(
self.mem_keys.assign(k),
self.mem_vals.assign(v),
self.mem_age.assign(a),
(self.recent_idx.assign(r) if r is not None else tf.group()))
def clear(self):
return tf.variables_initializer([self.mem_keys, self.mem_vals, self.mem_age,
self.recent_idx])
def get_hint_pool_idxs(self, normalized_query):
"""Get small set of idxs to compute nearest neighbor queries on.
This is an expensive look-up on the whole memory that is used to
avoid more expensive operations later on.
Args:
normalized_query: A Tensor of shape [None, key_dim].
Returns:
A Tensor of shape [None, choose_k] of indices in memory
that are closest to the queries.
"""
# look up in large memory, no gradients
with tf.device(self.nn_device):
similarities = tf.matmul(tf.stop_gradient(normalized_query),
self.mem_keys, transpose_b=True, name='nn_mmul')
_, hint_pool_idxs = tf.nn.top_k(
tf.stop_gradient(similarities), k=self.choose_k, name='nn_topk')
return hint_pool_idxs
def make_update_op(self, upd_idxs, upd_keys, upd_vals,
batch_size, use_recent_idx, intended_output):
"""Function that creates all the update ops."""
mem_age_incr = self.mem_age.assign_add(tf.ones([self.memory_size],
dtype=tf.float32))
with tf.control_dependencies([mem_age_incr]):
mem_age_upd = tf.scatter_update(
self.mem_age, upd_idxs, tf.zeros([batch_size], dtype=tf.float32))
mem_key_upd = tf.scatter_update(
self.mem_keys, upd_idxs, upd_keys)
mem_val_upd = tf.scatter_update(
self.mem_vals, upd_idxs, upd_vals)
if use_recent_idx:
recent_idx_upd = tf.scatter_update(
self.recent_idx, intended_output, upd_idxs)
else:
recent_idx_upd = tf.group()
return tf.group(mem_age_upd, mem_key_upd, mem_val_upd, recent_idx_upd)
def query(self, query_vec, intended_output, use_recent_idx=True):
"""Queries memory for nearest neighbor.
Args:
query_vec: A batch of vectors to query (embedding of input to model).
intended_output: The values that would be the correct output of the
memory.
use_recent_idx: Whether to always insert at least one instance of a
correct memory fetch.
Returns:
A tuple (result, mask, teacher_loss).
result: The result of the memory look up.
mask: The affinity of the query to the result.
teacher_loss: The loss for training the memory module.
"""
batch_size = tf.shape(query_vec)[0]
output_given = intended_output is not None
# prepare query for memory lookup
query_vec = tf.matmul(query_vec, self.query_proj)
normalized_query = tf.nn.l2_normalize(query_vec, dim=1)
hint_pool_idxs = self.get_hint_pool_idxs(normalized_query)
if output_given and use_recent_idx: # add at least one correct memory
most_recent_hint_idx = tf.gather(self.recent_idx, intended_output)
hint_pool_idxs = tf.concat([hint_pool_idxs,
tf.expand_dims(most_recent_hint_idx, 1)], 1)
choose_k = tf.shape(hint_pool_idxs)[1]
with tf.device(self.var_cache_device):
# create small memory and look up with gradients
my_mem_keys = tf.stop_gradient(tf.gather(self.mem_keys, hint_pool_idxs,
name='my_mem_keys_gather'))
similarities = tf.matmul(tf.expand_dims(normalized_query, 1),
my_mem_keys, adjoint_b=True, name='batch_mmul')
hint_pool_sims = tf.squeeze(similarities, [1], name='hint_pool_sims')
hint_pool_mem_vals = tf.gather(self.mem_vals, hint_pool_idxs,
name='hint_pool_mem_vals')
# Calculate softmax mask on the top-k if requested.
# Softmax temperature. Say we have K elements at dist x and one at (x+a).
# Softmax of the last is e^(tm*(x+a)) / (K*e^(tm*x) + e^(tm*(x+a)))
#   = e^(tm*a) / (K + e^(tm*a)).
# To make that 20% we'd need to have e^(tm*a) ~= 0.2*K, so tm = log(0.2*K) / a.
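# For the defaults choose_k=256 and alpha=0.1 this gives
# log(0.2 * 256) / 0.1 ~= 39.4, so the mask is sharply peaked on the
# closest keys.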
softmax_temp = max(1.0, np.log(0.2 * self.choose_k) / self.alpha)
mask = tf.nn.softmax(hint_pool_sims[:, :choose_k - 1] * softmax_temp)
# prepare hints from the teacher on hint pool
teacher_hints = tf.to_float(
tf.abs(tf.expand_dims(intended_output, 1) - hint_pool_mem_vals))
teacher_hints = 1.0 - tf.minimum(1.0, teacher_hints)
teacher_vals, teacher_hint_idxs = tf.nn.top_k(
hint_pool_sims * teacher_hints, k=1)
neg_teacher_vals, _ = tf.nn.top_k(
hint_pool_sims * (1 - teacher_hints), k=1)
# bring back idxs to full memory
teacher_idxs = tf.gather(
tf.reshape(hint_pool_idxs, [-1]),
teacher_hint_idxs[:, 0] + choose_k * tf.range(batch_size))
# zero-out teacher_vals if there are no hints
teacher_vals *= (
1 - tf.to_float(tf.equal(0.0, tf.reduce_sum(teacher_hints, 1))))
# prepare returned values
nearest_neighbor = tf.to_int32(
tf.argmax(hint_pool_sims[:, :choose_k - 1], 1))
no_teacher_idxs = tf.gather(
tf.reshape(hint_pool_idxs, [-1]),
nearest_neighbor + choose_k * tf.range(batch_size))
# we'll determine whether to do an update to memory based on whether
# memory was queried correctly
sliced_hints = tf.slice(teacher_hints, [0, 0], [-1, self.correct_in_top])
incorrect_memory_lookup = tf.equal(0.0, tf.reduce_sum(sliced_hints, 1))
# loss based on triplet loss
teacher_loss = (tf.nn.relu(neg_teacher_vals - teacher_vals + self.alpha)
- self.alpha)
with tf.device(self.var_cache_device):
result = tf.gather(self.mem_vals, tf.reshape(no_teacher_idxs, [-1]))
# prepare memory updates
update_keys = normalized_query
update_vals = intended_output
fetched_idxs = teacher_idxs # correctly fetched from memory
with tf.device(self.var_cache_device):
fetched_keys = tf.gather(self.mem_keys, fetched_idxs, name='fetched_keys')
fetched_vals = tf.gather(self.mem_vals, fetched_idxs, name='fetched_vals')
# do memory updates here
fetched_keys_upd = update_keys + fetched_keys # Momentum-like update
fetched_keys_upd = tf.nn.l2_normalize(fetched_keys_upd, dim=1)
# Randomize age a bit, e.g., to select different ones in parallel workers.
mem_age_with_noise = self.mem_age + tf.random_uniform(
[self.memory_size], - self.age_noise, self.age_noise)
_, oldest_idxs = tf.nn.top_k(mem_age_with_noise, k=batch_size, sorted=False)
with tf.control_dependencies([result]):
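# If the memory lookup was incorrect, write the new (key, value) pair into
# one of the oldest slots; otherwise refresh the correctly fetched slot
# with the averaged, re-normalized key.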
upd_idxs = tf.where(incorrect_memory_lookup,
oldest_idxs,
fetched_idxs)
# upd_idxs = tf.Print(upd_idxs, [upd_idxs], "UPD IDX", summarize=8)
upd_keys = tf.where(incorrect_memory_lookup,
update_keys,
fetched_keys_upd)
upd_vals = tf.where(incorrect_memory_lookup,
update_vals,
fetched_vals)
def make_update_op():
return self.make_update_op(upd_idxs, upd_keys, upd_vals,
batch_size, use_recent_idx, intended_output)
update_op = tf.cond(self.update_memory, make_update_op, tf.no_op)
with tf.control_dependencies([update_op]):
result = tf.identity(result)
mask = tf.identity(mask)
teacher_loss = tf.identity(teacher_loss)
return result, mask, tf.reduce_mean(teacher_loss)
class LSHMemory(Memory):
"""Memory employing locality sensitive hashing.
Note: Not fully tested.
"""
def __init__(self, key_dim, memory_size, vocab_size,
choose_k=256, alpha=0.1, correct_in_top=1, age_noise=8.0,
var_cache_device='', nn_device='',
num_hashes=None, num_libraries=None):
super(LSHMemory, self).__init__(
key_dim, memory_size, vocab_size,
choose_k=choose_k, alpha=alpha, correct_in_top=correct_in_top, age_noise=age_noise,
var_cache_device=var_cache_device, nn_device=nn_device)
self.num_libraries = num_libraries or int(self.choose_k ** 0.5)
self.num_per_hash_slot = max(1, self.choose_k // self.num_libraries)
self.num_hashes = (num_hashes or
int(np.log2(self.memory_size / self.num_per_hash_slot)))
self.num_hashes = min(max(self.num_hashes, 1), 20)
self.num_hash_slots = 2 ** self.num_hashes
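# With the defaults choose_k=256 and memory_size=8192 this gives
# num_libraries=16, num_per_hash_slot=16, num_hashes=9 and
# num_hash_slots=512.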
# hashing vectors
self.hash_vecs = [
tf.get_variable(
'hash_vecs%d' % i, [self.num_hashes, self.key_dim],
dtype=tf.float32, trainable=False,
initializer=tf.truncated_normal_initializer(0, 1))
for i in xrange(self.num_libraries)]
# map representing which hash slots map to which mem keys
self.hash_slots = [
tf.get_variable(
'hash_slots%d' % i, [self.num_hash_slots, self.num_per_hash_slot],
dtype=tf.int32, trainable=False,
initializer=tf.random_uniform_initializer(maxval=self.memory_size,
dtype=tf.int32))
for i in xrange(self.num_libraries)]
def get(self): # not implemented
return self.mem_keys, self.mem_vals, self.mem_age, self.recent_idx
def set(self, k, v, a, r=None): # not implemented
return tf.group(
self.mem_keys.assign(k),
self.mem_vals.assign(v),
self.mem_age.assign(a),
(self.recent_idx.assign(r) if r is not None else tf.group()))
def clear(self):
return tf.variables_initializer([self.mem_keys, self.mem_vals, self.mem_age,
self.recent_idx] + self.hash_slots)
def get_hash_slots(self, query):
"""Gets hashed-to buckets for batch of queries.
Args:
query: 2-d Tensor of query vectors.
Returns:
A list of hashed-to buckets for each hash function.
"""
binary_hash = [
tf.less(tf.matmul(query, self.hash_vecs[i], transpose_b=True), 0)
for i in xrange(self.num_libraries)]
hash_slot_idxs = [
tf.reduce_sum(
tf.to_int32(binary_hash[i]) *
tf.constant([[2 ** i for i in xrange(self.num_hashes)]],
dtype=tf.int32), 1)
for i in xrange(self.num_libraries)]
return hash_slot_idxs
def get_hint_pool_idxs(self, normalized_query):
"""Get small set of idxs to compute nearest neighbor queries on.
This look-up uses locality sensitive hashing of the query to avoid
an expensive dense search over the whole memory.
Args:
normalized_query: A Tensor of shape [None, key_dim].
Returns:
A Tensor of shape [None, choose_k] of indices in memory
that are closest to the queries.
"""
# get hash of query vecs
hash_slot_idxs = self.get_hash_slots(normalized_query)
# grab mem idxs in the hash slots
hint_pool_idxs = [
tf.maximum(tf.minimum(
tf.gather(self.hash_slots[i], idxs),
self.memory_size - 1), 0)
for i, idxs in enumerate(hash_slot_idxs)]
return tf.concat(hint_pool_idxs, 1)
def make_update_op(self, upd_idxs, upd_keys, upd_vals,
batch_size, use_recent_idx, intended_output):
"""Function that creates all the update ops."""
base_update_op = super(LSHMemory, self).make_update_op(
upd_idxs, upd_keys, upd_vals,
batch_size, use_recent_idx, intended_output)
# compute hash slots to be updated
hash_slot_idxs = self.get_hash_slots(upd_keys)
# make updates
update_ops = []
with tf.control_dependencies([base_update_op]):
for i, slot_idxs in enumerate(hash_slot_idxs):
# for each slot, choose which entry to replace
entry_idx = tf.random_uniform([batch_size],
maxval=self.num_per_hash_slot,
dtype=tf.int32)
entry_mul = 1 - tf.one_hot(entry_idx, self.num_per_hash_slot,
dtype=tf.int32)
entry_add = (tf.expand_dims(upd_idxs, 1) *
tf.one_hot(entry_idx, self.num_per_hash_slot,
dtype=tf.int32))
mul_op = tf.scatter_mul(self.hash_slots[i], slot_idxs, entry_mul)
with tf.control_dependencies([mul_op]):
add_op = tf.scatter_add(self.hash_slots[i], slot_idxs, entry_add)
update_ops.append(add_op)
return tf.group(*update_ops)
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
"""Model using memory component.
The model embeds images using a standard CNN architecture.
These embeddings are used as keys to the memory component,
which returns nearest neighbors.
"""
import tensorflow as tf
import memory
FLAGS = tf.flags.FLAGS
class BasicClassifier(object):
def __init__(self, output_dim):
self.output_dim = output_dim
def core_builder(self, memory_val, x, y):
del x, y
y_pred = memory_val
loss = 0.0
return loss, y_pred
class LeNet(object):
"""Standard CNN architecture."""
def __init__(self, image_size, num_channels, hidden_dim):
self.image_size = image_size
self.num_channels = num_channels
self.hidden_dim = hidden_dim
self.matrix_init = tf.truncated_normal_initializer(stddev=0.1)
self.vector_init = tf.constant_initializer(0.0)
def core_builder(self, x):
"""Embeds x using standard CNN architecture.
Args:
x: Batch of images as a 2-d Tensor [batch_size, -1].
Returns:
A 2-d Tensor [batch_size, hidden_dim] of embedded images.
"""
ch1 = 32 * 2 # number of channels in 1st layer
ch2 = 64 * 2 # number of channels in 2nd layer
conv1_weights = tf.get_variable('conv1_w',
[3, 3, self.num_channels, ch1],
initializer=self.matrix_init)
conv1_biases = tf.get_variable('conv1_b', [ch1],
initializer=self.vector_init)
conv1a_weights = tf.get_variable('conv1a_w',
[3, 3, ch1, ch1],
initializer=self.matrix_init)
conv1a_biases = tf.get_variable('conv1a_b', [ch1],
initializer=self.vector_init)
conv2_weights = tf.get_variable('conv2_w', [3, 3, ch1, ch2],
initializer=self.matrix_init)
conv2_biases = tf.get_variable('conv2_b', [ch2],
initializer=self.vector_init)
conv2a_weights = tf.get_variable('conv2a_w', [3, 3, ch2, ch2],
initializer=self.matrix_init)
conv2a_biases = tf.get_variable('conv2a_b', [ch2],
initializer=self.vector_init)
# fully connected
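# The two 2x2 max-pools below halve each spatial dimension twice, so the
# flattened conv output has (image_size // 4)**2 * ch2 entries
# (7 * 7 * 128 = 6272 for the 28x28 Omniglot images).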
fc1_weights = tf.get_variable(
'fc1_w', [self.image_size // 4 * self.image_size // 4 * ch2,
self.hidden_dim], initializer=self.matrix_init)
fc1_biases = tf.get_variable('fc1_b', [self.hidden_dim],
initializer=self.vector_init)
# define model
x = tf.reshape(x,
[-1, self.image_size, self.image_size, self.num_channels])
batch_size = tf.shape(x)[0]
conv1 = tf.nn.conv2d(x, conv1_weights,
strides=[1, 1, 1, 1], padding='SAME')
relu1 = tf.nn.relu(tf.nn.bias_add(conv1, conv1_biases))
conv1 = tf.nn.conv2d(relu1, conv1a_weights,
strides=[1, 1, 1, 1], padding='SAME')
relu1 = tf.nn.relu(tf.nn.bias_add(conv1, conv1a_biases))
pool1 = tf.nn.max_pool(relu1, ksize=[1, 2, 2, 1],
strides=[1, 2, 2, 1], padding='SAME')
conv2 = tf.nn.conv2d(pool1, conv2_weights,
strides=[1, 1, 1, 1], padding='SAME')
relu2 = tf.nn.relu(tf.nn.bias_add(conv2, conv2_biases))
conv2 = tf.nn.conv2d(relu2, conv2a_weights,
strides=[1, 1, 1, 1], padding='SAME')
relu2 = tf.nn.relu(tf.nn.bias_add(conv2, conv2a_biases))
pool2 = tf.nn.max_pool(relu2, ksize=[1, 2, 2, 1],
strides=[1, 2, 2, 1], padding='SAME')
reshape = tf.reshape(pool2, [batch_size, -1])
hidden = tf.matmul(reshape, fc1_weights) + fc1_biases
return hidden
class Model(object):
"""Model for coordinating between CNN embedder and Memory module."""
def __init__(self, input_dim, output_dim, rep_dim, memory_size, vocab_size,
learning_rate=0.0001, use_lsh=False):
self.input_dim = input_dim
self.output_dim = output_dim
self.rep_dim = rep_dim
self.memory_size = memory_size
self.vocab_size = vocab_size
self.learning_rate = learning_rate
self.use_lsh = use_lsh
self.embedder = self.get_embedder()
self.memory = self.get_memory()
self.classifier = self.get_classifier()
self.global_step = tf.contrib.framework.get_or_create_global_step()
def get_embedder(self):
return LeNet(int(self.input_dim ** 0.5), 1, self.rep_dim)
def get_memory(self):
cls = memory.LSHMemory if self.use_lsh else memory.Memory
return cls(self.rep_dim, self.memory_size, self.vocab_size)
def get_classifier(self):
return BasicClassifier(self.output_dim)
def core_builder(self, x, y, keep_prob, use_recent_idx=True):
embeddings = self.embedder.core_builder(x)
if keep_prob < 1.0:
embeddings = tf.nn.dropout(embeddings, keep_prob)
memory_val, _, teacher_loss = self.memory.query(
embeddings, y, use_recent_idx=use_recent_idx)
loss, y_pred = self.classifier.core_builder(memory_val, x, y)
return loss + teacher_loss, y_pred
def train(self, x, y):
loss, _ = self.core_builder(x, y, keep_prob=0.3)
gradient_ops = self.training_ops(loss)
return loss, gradient_ops
def eval(self, x, y):
_, y_preds = self.core_builder(x, y, keep_prob=1.0,
use_recent_idx=False)
return y_preds
def get_xy_placeholders(self):
return (tf.placeholder(tf.float32, [None, self.input_dim]),
tf.placeholder(tf.int32, [None]))
def setup(self):
"""Sets up all components of the computation graph."""
self.x, self.y = self.get_xy_placeholders()
with tf.variable_scope('core', reuse=None):
self.loss, self.gradient_ops = self.train(self.x, self.y)
with tf.variable_scope('core', reuse=True):
self.y_preds = self.eval(self.x, self.y)
# setup memory "reset" ops
(self.mem_keys, self.mem_vals,
self.mem_age, self.recent_idx) = self.memory.get()
self.mem_keys_reset = tf.placeholder(self.mem_keys.dtype,
tf.identity(self.mem_keys).shape)
self.mem_vals_reset = tf.placeholder(self.mem_vals.dtype,
tf.identity(self.mem_vals).shape)
self.mem_age_reset = tf.placeholder(self.mem_age.dtype,
tf.identity(self.mem_age).shape)
self.recent_idx_reset = tf.placeholder(self.recent_idx.dtype,
tf.identity(self.recent_idx).shape)
self.mem_reset_op = self.memory.set(self.mem_keys_reset,
self.mem_vals_reset,
self.mem_age_reset,
None)
def training_ops(self, loss):
opt = self.get_optimizer()
params = tf.trainable_variables()
gradients = tf.gradients(loss, params)
clipped_gradients, _ = tf.clip_by_global_norm(gradients, 5.0)
return opt.apply_gradients(zip(clipped_gradients, params),
global_step=self.global_step)
def get_optimizer(self):
return tf.train.AdamOptimizer(learning_rate=self.learning_rate,
epsilon=1e-4)
def one_step(self, sess, x, y):
outputs = [self.loss, self.gradient_ops]
return sess.run(outputs, feed_dict={self.x: x, self.y: y})
def episode_step(self, sess, x, y, clear_memory=False):
"""Performs training steps on episodic input.
Args:
sess: A Tensorflow Session.
x: A list of batches of images defining the episode.
y: A list of batches of labels corresponding to x.
clear_memory: Whether to clear the memory before the episode.
Returns:
List of losses the same length as the episode.
"""
outputs = [self.loss, self.gradient_ops]
if clear_memory:
self.clear_memory(sess)
losses = []
for xx, yy in zip(x, y):
out = sess.run(outputs, feed_dict={self.x: xx, self.y: yy})
loss = out[0]
losses.append(loss)
return losses
def predict(self, sess, x, y=None):
"""Predict the labels on a single batch of examples.
Args:
sess: A Tensorflow Session.
x: A batch of images.
y: The labels for the images in x.
This allows for updating the memory.
Returns:
Predicted y.
"""
cur_memory = sess.run([self.mem_keys, self.mem_vals,
self.mem_age])
outputs = [self.y_preds]
if y is None:
ret = sess.run(outputs, feed_dict={self.x: x})
else:
ret = sess.run(outputs, feed_dict={self.x: x, self.y: y})
sess.run([self.mem_reset_op],
feed_dict={self.mem_keys_reset: cur_memory[0],
self.mem_vals_reset: cur_memory[1],
self.mem_age_reset: cur_memory[2]})
return ret
def episode_predict(self, sess, x, y, clear_memory=False):
"""Predict the labels on an episode of examples.
Args:
sess: A Tensorflow Session.
x: A list of batches of images.
y: A list of labels for the images in x.
This allows for updating the memory.
clear_memory: Whether to clear the memory before the episode.
Returns:
List of predicted y.
"""
cur_memory = sess.run([self.mem_keys, self.mem_vals,
self.mem_age])
if clear_memory:
self.clear_memory(sess)
outputs = [self.y_preds]
y_preds = []
for xx, yy in zip(x, y):
out = sess.run(outputs, feed_dict={self.x: xx, self.y: yy})
y_pred = out[0]
y_preds.append(y_pred)
sess.run([self.mem_reset_op],
feed_dict={self.mem_keys_reset: cur_memory[0],
self.mem_vals_reset: cur_memory[1],
self.mem_age_reset: cur_memory[2]})
return y_preds
def clear_memory(self, sess):
sess.run([self.memory.clear()])
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
r"""Script for training model.
Simple command to get up and running:
python train.py --memory_size=8192 \
--batch_size=16 --validation_length=50 \
--episode_width=5 --episode_length=30
"""
import logging
import os
import random
import numpy as np
import tensorflow as tf
import data_utils
import model
FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_integer('rep_dim', 128,
'dimension of keys to use in memory')
tf.flags.DEFINE_integer('episode_length', 100, 'length of episode')
tf.flags.DEFINE_integer('episode_width', 5,
'number of distinct labels in a single episode')
tf.flags.DEFINE_integer('memory_size', None, 'number of slots in memory. '
'Leave as None to default to episode length')
tf.flags.DEFINE_integer('batch_size', 16, 'batch size')
tf.flags.DEFINE_integer('num_episodes', 100000, 'number of training episodes')
tf.flags.DEFINE_integer('validation_frequency', 20,
'every so many training episodes, '
'assess validation accuracy')
tf.flags.DEFINE_integer('validation_length', 10,
'number of episodes to use to compute '
'validation accuracy')
tf.flags.DEFINE_integer('seed', 888, 'random seed for training sampling')
tf.flags.DEFINE_string('save_dir', '', 'directory to save model to')
tf.flags.DEFINE_bool('use_lsh', False,
'use locality-sensitive hashing '
'(NOTE: not fully tested)')
class Trainer(object):
"""Class that takes care of training, validating, and checkpointing model."""
def __init__(self, train_data, valid_data, input_dim, output_dim=None):
self.train_data = train_data
self.valid_data = valid_data
self.input_dim = input_dim
self.rep_dim = FLAGS.rep_dim
self.episode_length = FLAGS.episode_length
self.episode_width = FLAGS.episode_width
self.batch_size = FLAGS.batch_size
self.memory_size = (self.episode_length * self.batch_size
if FLAGS.memory_size is None else FLAGS.memory_size)
self.use_lsh = FLAGS.use_lsh
self.output_dim = (output_dim if output_dim is not None
else self.episode_width)
def get_model(self):
# vocab size is the number of distinct values that
# could go into the memory key-value storage
vocab_size = self.episode_width * self.batch_size
return model.Model(
self.input_dim, self.output_dim, self.rep_dim, self.memory_size,
vocab_size, use_lsh=self.use_lsh)
def sample_episode_batch(self, data,
episode_length, episode_width, batch_size):
"""Generates a random batch for training or validation.
Structures each element of the batch as an 'episode'.
Each episode contains episode_length examples and
episode_width distinct labels.
Args:
data: A dictionary mapping label to list of examples.
episode_length: Number of examples in each episode.
episode_width: Distinct number of labels in each episode.
batch_size: Batch size (number of episodes).
Returns:
A tuple (x, y) where x is a list of batches of examples
with size episode_length and y is a list of batches of labels.
"""
episodes_x = [[] for _ in xrange(episode_length)]
episodes_y = [[] for _ in xrange(episode_length)]
assert len(data) >= episode_width
keys = data.keys()
for b in xrange(batch_size):
episode_labels = random.sample(keys, episode_width)
remainder = episode_length % episode_width
remainders = [0] * (episode_width - remainder) + [1] * remainder
episode_x = [
random.sample(data[lab],
r + (episode_length - remainder) / episode_width)
for lab, r in zip(episode_labels, remainders)]
episode = sum([[(x, i, ii) for ii, x in enumerate(xx)]
for i, xx in enumerate(episode_x)], [])
random.shuffle(episode)
# Arrange episode so that each distinct label is seen before moving to
# 2nd showing
episode.sort(key=lambda elem: elem[2])
assert len(episode) == episode_length
for i in xrange(episode_length):
episodes_x[i].append(episode[i][0])
episodes_y[i].append(episode[i][1] + b * episode_width)
return ([np.array(xx).astype('float32') for xx in episodes_x],
[np.array(yy).astype('int32') for yy in episodes_y])
def compute_correct(self, ys, y_preds):
return np.mean(np.equal(y_preds, np.array(ys)))
def individual_compute_correct(self, y, y_pred):
return y_pred == y
def run(self):
"""Performs training.
Trains a model using episodic training.
Every so often, runs some evaluations on validation data.
"""
train_data, valid_data = self.train_data, self.valid_data
input_dim, output_dim = self.input_dim, self.output_dim
rep_dim, episode_length = self.rep_dim, self.episode_length
episode_width, memory_size = self.episode_width, self.memory_size
batch_size = self.batch_size
train_size = len(train_data)
valid_size = len(valid_data)
logging.info('train_size (number of labels) %d', train_size)
logging.info('valid_size (number of labels) %d', valid_size)
logging.info('input_dim %d', input_dim)
logging.info('output_dim %d', output_dim)
logging.info('rep_dim %d', rep_dim)
logging.info('episode_length %d', episode_length)
logging.info('episode_width %d', episode_width)
logging.info('memory_size %d', memory_size)
logging.info('batch_size %d', batch_size)
assert all(len(v) >= float(episode_length) / episode_width
for v in train_data.itervalues())
assert all(len(v) >= float(episode_length) / episode_width
for v in valid_data.itervalues())
output_dim = episode_width
self.model = self.get_model()
self.model.setup()
sess = tf.Session()
sess.run(tf.initialize_all_variables())
saver = tf.train.Saver(max_to_keep=10)
ckpt = None
if FLAGS.save_dir:
ckpt = tf.train.get_checkpoint_state(FLAGS.save_dir)
if ckpt and ckpt.model_checkpoint_path:
logging.info('restoring from %s', ckpt.model_checkpoint_path)
saver.restore(sess, ckpt.model_checkpoint_path)
logging.info('starting now')
losses = []
random.seed(FLAGS.seed)
np.random.seed(FLAGS.seed)
for i in xrange(FLAGS.num_episodes):
x, y = self.sample_episode_batch(
train_data, episode_length, episode_width, batch_size)
outputs = self.model.episode_step(sess, x, y, clear_memory=True)
loss = outputs
losses.append(loss)
if i % FLAGS.validation_frequency == 0:
logging.info('episode batch %d, avg train loss %f',
i, np.mean(losses))
losses = []
# validation
correct = []
correct_by_shot = dict((k, []) for k in xrange(self.episode_width + 1))
for _ in xrange(FLAGS.validation_length):
x, y = self.sample_episode_batch(
valid_data, episode_length, episode_width, 1)
outputs = self.model.episode_predict(
sess, x, y, clear_memory=True)
y_preds = outputs
correct.append(self.compute_correct(np.array(y), y_preds))
# compute per-shot accuracies
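# correct_by_shot[k] collects predictions made the (k+1)-th time a label is
# queried within the episode; k=0 means the label has not been seen yet
# (0-shot).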
seen_counts = [[0] * episode_width for _ in xrange(batch_size)]
# loop over episode steps
for yy, yy_preds in zip(y, y_preds):
# loop over batch examples
for k, (yyy, yyy_preds) in enumerate(zip(yy, yy_preds)):
yyy, yyy_preds = int(yyy), int(yyy_preds)
count = seen_counts[k][yyy % self.episode_width]
if count in correct_by_shot:
correct_by_shot[count].append(
self.individual_compute_correct(yyy, yyy_preds))
seen_counts[k][yyy % self.episode_width] = count + 1
logging.info('validation overall accuracy %f', np.mean(correct))
logging.info('%d-shot: %.3f, ' * (self.episode_width + 1),
*sum([[k, np.mean(correct_by_shot[k])]
for k in xrange(self.episode_width + 1)], []))
if saver and FLAGS.save_dir:
saved_file = saver.save(sess,
os.path.join(FLAGS.save_dir, 'model.ckpt'),
global_step=self.model.global_step)
logging.info('saved model to %s', saved_file)
def main(unused_argv):
train_data, valid_data = data_utils.get_data()
trainer = Trainer(train_data, valid_data, data_utils.IMAGE_NEW_SIZE ** 2)
trainer.run()
if __name__ == '__main__':
logging.basicConfig(level=logging.INFO)
tf.app.run()