ModelZoo / ResNet50_tensorflow / Commits

Commit 6a9c0da9, authored Mar 01, 2017 by Ofir Nachum

add learning to remember rare events

parent bc70271a
Showing 5 changed files with 1231 additions and 0 deletions (+1231 -0)
learning_to_remember_rare_events/README.md      +55  -0
learning_to_remember_rare_events/data_utils.py  +242 -0
learning_to_remember_rare_events/memory.py      +385 -0
learning_to_remember_rare_events/model.py       +308 -0
learning_to_remember_rare_events/train.py       +241 -0
learning_to_remember_rare_events/README.md (new file, 0 → 100644)
Code for the Memory Module as described
in "Learning to Remember Rare Events" by
Lukasz Kaiser, Ofir Nachum, Aurko Roy, and Samy Bengio
published as a conference paper at ICLR 2017.
Requirements:
* TensorFlow (see tensorflow.org for how to install)
* Some basic command-line utilities (git, unzip).
Description:
The general memory module is located in memory.py.
Some code is provided to see the memory module in
action on the standard Omniglot dataset.
Download and setup the dataset using data_utils.py
and then run the training script train.py
(see example commands below).
Note that the structure and parameters of the model
are optimized for the data preparation as provided.
Quick Start:
First download and set up the Omniglot data by running
```
python data_utils.py
```
Then run the training script:
```
python train.py --memory_size=8192 \
--batch_size=16 --validation_length=50 \
--episode_width=5 --episode_length=30
```
The first validation batch may look like this (although it is noisy):
```
0-shot: 0.040, 1-shot: 0.404, 2-shot: 0.516, 3-shot: 0.604,
4-shot: 0.656, 5-shot: 0.684
```
At step 500 you may see something like this:
```
0-shot: 0.036, 1-shot: 0.836, 2-shot: 0.900, 3-shot: 0.940,
4-shot: 0.944, 5-shot: 0.916
```
At step 4000 you may see something like this:
```
0-shot: 0.044, 1-shot: 0.960, 2-shot: 1.000, 3-shot: 0.988,
4-shot: 0.972, 5-shot: 0.992
```
Maintained by Ofir Nachum (ofirnachum) and
Lukasz Kaiser (lukaszkaiser).
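
For illustration only (not part of this commit), here is a minimal sketch of how
the memory module in memory.py can be wired into a graph on its own. It assumes
the `Memory` constructor and `query` signature defined in memory.py; the
placeholder shapes and sizes (key_dim=128, memory_size=8192, and
vocab_size = episode_width * batch_size = 80, as in train.py) are example values
taken from the training command above.
```
# Hedged usage sketch; model.py shows how this commit actually wires the
# memory into the full CNN + memory model.
import tensorflow as tf
import memory

key_dim, memory_size, vocab_size = 128, 8192, 5 * 16
mem = memory.Memory(key_dim, memory_size, vocab_size)

embeddings = tf.placeholder(tf.float32, [None, key_dim])  # query vectors
labels = tf.placeholder(tf.int32, [None])                 # intended outputs

# result: values fetched from memory; mask: affinity of query to result;
# teacher_loss: scalar loss used to train the memory (and the embedder).
result, mask, teacher_loss = mem.query(embeddings, labels)
```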
learning_to_remember_rare_events/data_utils.py (new file, 0 → 100644)
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
"""Data loading and other utilities.
Use this file to first copy over and pre-process the Omniglot dataset.
Simply call
python data_utils.py
"""
import cPickle as pickle
import logging
import os
import subprocess

import numpy as np
from scipy.misc import imresize
from scipy.misc import imrotate
from scipy.ndimage import imread
import tensorflow as tf


MAIN_DIR = ''
REPO_LOCATION = 'https://github.com/brendenlake/omniglot.git'
REPO_DIR = os.path.join(MAIN_DIR, 'omniglot')
DATA_DIR = os.path.join(REPO_DIR, 'python')
TRAIN_DIR = os.path.join(DATA_DIR, 'images_background')
TEST_DIR = os.path.join(DATA_DIR, 'images_evaluation')
DATA_FILE_FORMAT = os.path.join(MAIN_DIR, '%s_omni.pkl')

TRAIN_ROTATIONS = True  # augment training data with rotations
TEST_ROTATIONS = False  # augment testing data with rotations
IMAGE_ORIGINAL_SIZE = 105
IMAGE_NEW_SIZE = 28


def get_data():
  """Get data in form suitable for episodic training.

  Returns:
    Train and test data as dictionaries mapping
    label to list of examples.
  """
  with tf.gfile.GFile(DATA_FILE_FORMAT % 'train') as f:
    processed_train_data = pickle.load(f)
  with tf.gfile.GFile(DATA_FILE_FORMAT % 'test') as f:
    processed_test_data = pickle.load(f)

  train_data = {}
  test_data = {}

  for data, processed_data in zip([train_data, test_data],
                                  [processed_train_data, processed_test_data]):
    for image, label in zip(processed_data['images'],
                            processed_data['labels']):
      if label not in data:
        data[label] = []
      data[label].append(image.reshape([-1]).astype('float32'))

  intersection = set(train_data.keys()) & set(test_data.keys())
  assert not intersection, 'Train and test data intersect.'
  ok_num_examples = [len(ll) == 20 for _, ll in train_data.iteritems()]
  assert all(ok_num_examples), 'Bad number of examples in train data.'
  ok_num_examples = [len(ll) == 20 for _, ll in test_data.iteritems()]
  assert all(ok_num_examples), 'Bad number of examples in test data.'

  logging.info('Number of labels in train data: %d.', len(train_data))
  logging.info('Number of labels in test data: %d.', len(test_data))

  return train_data, test_data


def crawl_directory(directory, augment_with_rotations=False, first_label=0):
  """Crawls data directory and returns stuff."""
  label_idx = first_label
  images = []
  labels = []
  info = []

  # traverse root directory
  for root, _, files in os.walk(directory):
    logging.info('Reading files from %s', root)
    fileflag = 0
    for file_name in files:
      full_file_name = os.path.join(root, file_name)
      img = imread(full_file_name, flatten=True)
      for i, angle in enumerate([0, 90, 180, 270]):
        if not augment_with_rotations and i > 0:
          break

        images.append(imrotate(img, angle))
        labels.append(label_idx + i)
        info.append(full_file_name)

      fileflag = 1

    if fileflag:
      label_idx += 4 if augment_with_rotations else 1

  return images, labels, info


def resize_images(images, new_width, new_height):
  """Resize images to new dimensions."""
  resized_images = np.zeros([images.shape[0], new_width, new_height],
                            dtype=np.float32)

  for i in range(images.shape[0]):
    resized_images[i, :, :] = imresize(images[i, :, :],
                                       [new_width, new_height],
                                       interp='bilinear',
                                       mode=None)
  return resized_images


def write_datafiles(directory, write_file,
                    resize=True, rotate=False,
                    new_width=IMAGE_NEW_SIZE, new_height=IMAGE_NEW_SIZE,
                    first_label=0):
  """Load and preprocess images from a directory and write them to a file.

  Args:
    directory: Directory of alphabet sub-directories.
    write_file: Filename to write to.
    resize: Whether to resize the images.
    rotate: Whether to augment the dataset with rotations.
    new_width: New resize width.
    new_height: New resize height.
    first_label: Label to start with.

  Returns:
    Number of new labels created.
  """

  # these are the default sizes for Omniglot:
  imgwidth = IMAGE_ORIGINAL_SIZE
  imgheight = IMAGE_ORIGINAL_SIZE

  logging.info('Reading the data.')
  images, labels, info = crawl_directory(directory,
                                         augment_with_rotations=rotate,
                                         first_label=first_label)

  images_np = np.zeros([len(images), imgwidth, imgheight], dtype=np.bool)
  labels_np = np.zeros([len(labels)], dtype=np.uint32)
  for i in xrange(len(images)):
    images_np[i, :, :] = images[i]
    labels_np[i] = labels[i]

  if resize:
    logging.info('Resizing images.')
    resized_images = resize_images(images_np, new_width, new_height)

    logging.info('Writing resized data in float32 format.')
    data = {'images': resized_images,
            'labels': labels_np,
            'info': info}
    with tf.gfile.GFile(write_file, 'w') as f:
      pickle.dump(data, f)
  else:
    logging.info('Writing original sized data in boolean format.')
    data = {'images': images_np,
            'labels': labels_np,
            'info': info}
    with tf.gfile.GFile(write_file, 'w') as f:
      pickle.dump(data, f)

  return len(np.unique(labels_np))


def maybe_download_data():
  """Download Omniglot repo if it does not exist."""
  if os.path.exists(REPO_DIR):
    logging.info('It appears that Git repo already exists.')
  else:
    logging.info('It appears that Git repo does not exist.')
    logging.info('Cloning now.')

    subprocess.check_output('git clone %s' % REPO_LOCATION, shell=True)

  if os.path.exists(TRAIN_DIR):
    logging.info('It appears that train data has already been unzipped.')
  else:
    logging.info('It appears that train data has not been unzipped.')
    logging.info('Unzipping now.')

    subprocess.check_output('unzip %s.zip -d %s' % (TRAIN_DIR, DATA_DIR),
                            shell=True)

  if os.path.exists(TEST_DIR):
    logging.info('It appears that test data has already been unzipped.')
  else:
    logging.info('It appears that test data has not been unzipped.')
    logging.info('Unzipping now.')

    subprocess.check_output('unzip %s.zip -d %s' % (TEST_DIR, DATA_DIR),
                            shell=True)


def preprocess_omniglot():
  """Download and prepare raw Omniglot data.

  Downloads the data from GitHub if it does not exist.
  Then load the images, augment with rotations if desired.
  Resize the images and write them to a pickle file.
  """

  maybe_download_data()

  directory = TRAIN_DIR
  write_file = DATA_FILE_FORMAT % 'train'
  num_labels = write_datafiles(
      directory, write_file, resize=True, rotate=TRAIN_ROTATIONS,
      new_width=IMAGE_NEW_SIZE, new_height=IMAGE_NEW_SIZE)

  directory = TEST_DIR
  write_file = DATA_FILE_FORMAT % 'test'
  write_datafiles(directory, write_file, resize=True, rotate=TEST_ROTATIONS,
                  new_width=IMAGE_NEW_SIZE, new_height=IMAGE_NEW_SIZE,
                  first_label=num_labels)


def main(unused_argv):
  logging.basicConfig(level=logging.INFO)
  preprocess_omniglot()


if __name__ == '__main__':
  tf.app.run()
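
As a quick check that the preprocessing above worked, here is a hedged sketch (not
part of this commit) that loads the pickles written by data_utils.py via
get_data(); `some_label` is just an arbitrary key pulled from the returned
dictionary.
```
# Hedged sketch: load the '%s_omni.pkl' files written above and check shapes.
# Assumes data_utils.py has already been run in the same directory.
import data_utils

train_data, test_data = data_utils.get_data()
some_label = next(iter(train_data))         # arbitrary label key
print(len(train_data), len(test_data))      # number of train/test labels
print(train_data[some_label][0].shape)      # (784,) = 28 * 28, flattened
```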
learning_to_remember_rare_events/memory.py (new file, 0 → 100644)
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
"""Memory module for storing "nearest neighbors".
Implements a key-value memory for generalized one-shot learning
as described in the paper
"Learning to Remember Rare Events"
by Lukasz Kaiser, Ofir Nachum, Aurko Roy, Samy Bengio,
published as a conference paper at ICLR 2017.
"""
import numpy as np
import tensorflow as tf


class Memory(object):
  """Memory module."""

  def __init__(self, key_dim, memory_size, vocab_size,
               choose_k=256, alpha=0.1, correct_in_top=1, age_noise=8.0,
               var_cache_device='', nn_device=''):
    self.key_dim = key_dim
    self.memory_size = memory_size
    self.vocab_size = vocab_size
    self.choose_k = min(choose_k, memory_size)
    self.alpha = alpha
    self.correct_in_top = correct_in_top
    self.age_noise = age_noise
    self.var_cache_device = var_cache_device  # Variables are cached here.
    self.nn_device = nn_device  # Device to perform nearest neighbour matmul.

    caching_device = var_cache_device if var_cache_device else None
    self.update_memory = tf.constant(True)  # Can be fed "false" if needed.
    self.mem_keys = tf.get_variable(
        'memkeys', [self.memory_size, self.key_dim], trainable=False,
        initializer=tf.random_uniform_initializer(-0.0, 0.0),
        caching_device=caching_device)
    self.mem_vals = tf.get_variable(
        'memvals', [self.memory_size], dtype=tf.int32, trainable=False,
        initializer=tf.constant_initializer(0, tf.int32),
        caching_device=caching_device)
    self.mem_age = tf.get_variable(
        'memage', [self.memory_size], dtype=tf.float32, trainable=False,
        initializer=tf.constant_initializer(0.0),
        caching_device=caching_device)
    self.recent_idx = tf.get_variable(
        'recent_idx', [self.vocab_size], dtype=tf.int32, trainable=False,
        initializer=tf.constant_initializer(0, tf.int32))

    # variable for projecting query vector into memory key
    self.query_proj = tf.get_variable(
        'memory_query_proj', [self.key_dim, self.key_dim], dtype=tf.float32,
        initializer=tf.truncated_normal_initializer(0, 0.01),
        caching_device=caching_device)

  def get(self):
    return self.mem_keys, self.mem_vals, self.mem_age, self.recent_idx

  def set(self, k, v, a, r=None):
    return tf.group(
        self.mem_keys.assign(k),
        self.mem_vals.assign(v),
        self.mem_age.assign(a),
        (self.recent_idx.assign(r) if r is not None else tf.group()))

  def clear(self):
    return tf.variables_initializer([self.mem_keys, self.mem_vals,
                                     self.mem_age, self.recent_idx])

  def get_hint_pool_idxs(self, normalized_query):
    """Get small set of idxs to compute nearest neighbor queries on.

    This is an expensive look-up on the whole memory that is used to
    avoid more expensive operations later on.

    Args:
      normalized_query: A Tensor of shape [None, key_dim].

    Returns:
      A Tensor of shape [None, choose_k] of indices in memory
      that are closest to the queries.
    """
    # look up in large memory, no gradients
    with tf.device(self.nn_device):
      similarities = tf.matmul(tf.stop_gradient(normalized_query),
                               self.mem_keys, transpose_b=True,
                               name='nn_mmul')
    _, hint_pool_idxs = tf.nn.top_k(
        tf.stop_gradient(similarities), k=self.choose_k, name='nn_topk')
    return hint_pool_idxs

  def make_update_op(self, upd_idxs, upd_keys, upd_vals,
                     batch_size, use_recent_idx, intended_output):
    """Function that creates all the update ops."""
    mem_age_incr = self.mem_age.assign_add(tf.ones([self.memory_size],
                                                   dtype=tf.float32))
    with tf.control_dependencies([mem_age_incr]):
      mem_age_upd = tf.scatter_update(
          self.mem_age, upd_idxs, tf.zeros([batch_size], dtype=tf.float32))

    mem_key_upd = tf.scatter_update(self.mem_keys, upd_idxs, upd_keys)
    mem_val_upd = tf.scatter_update(self.mem_vals, upd_idxs, upd_vals)

    if use_recent_idx:
      recent_idx_upd = tf.scatter_update(self.recent_idx, intended_output,
                                         upd_idxs)
    else:
      recent_idx_upd = tf.group()

    return tf.group(mem_age_upd, mem_key_upd, mem_val_upd, recent_idx_upd)

  def query(self, query_vec, intended_output, use_recent_idx=True):
    """Queries memory for nearest neighbor.

    Args:
      query_vec: A batch of vectors to query (embedding of input to model).
      intended_output: The values that would be the correct output of the
        memory.
      use_recent_idx: Whether to always insert at least one instance of a
        correct memory fetch.

    Returns:
      A tuple (result, mask, teacher_loss).
      result: The result of the memory look up.
      mask: The affinity of the query to the result.
      teacher_loss: The loss for training the memory module.
    """

    batch_size = tf.shape(query_vec)[0]
    output_given = intended_output is not None

    # prepare query for memory lookup
    query_vec = tf.matmul(query_vec, self.query_proj)
    normalized_query = tf.nn.l2_normalize(query_vec, dim=1)

    hint_pool_idxs = self.get_hint_pool_idxs(normalized_query)

    if output_given and use_recent_idx:  # add at least one correct memory
      most_recent_hint_idx = tf.gather(self.recent_idx, intended_output)
      hint_pool_idxs = tf.concat(
          [hint_pool_idxs, tf.expand_dims(most_recent_hint_idx, 1)], 1)
    choose_k = tf.shape(hint_pool_idxs)[1]

    with tf.device(self.var_cache_device):
      # create small memory and look up with gradients
      my_mem_keys = tf.stop_gradient(tf.gather(self.mem_keys, hint_pool_idxs,
                                               name='my_mem_keys_gather'))
      similarities = tf.matmul(tf.expand_dims(normalized_query, 1),
                               my_mem_keys, adjoint_b=True, name='batch_mmul')
      hint_pool_sims = tf.squeeze(similarities, [1], name='hint_pool_sims')
      hint_pool_mem_vals = tf.gather(self.mem_vals, hint_pool_idxs,
                                     name='hint_pool_mem_vals')
    # Calculate softmax mask on the top-k if requested.
    # Softmax temperature. Say we have K elements at dist x and one at (x+a).
    # Softmax of the last is e^tm(x+a)/Ke^tm*x + e^tm(x+a) = e^tm*a/K+e^tm*a.
    # To make that 20% we'd need to have e^tm*a ~= 0.2K, so tm = log(0.2K)/a.
    softmax_temp = max(1.0, np.log(0.2 * self.choose_k) / self.alpha)
    mask = tf.nn.softmax(hint_pool_sims[:, :choose_k - 1] * softmax_temp)

    # prepare hints from the teacher on hint pool
    teacher_hints = tf.to_float(
        tf.abs(tf.expand_dims(intended_output, 1) - hint_pool_mem_vals))
    teacher_hints = 1.0 - tf.minimum(1.0, teacher_hints)

    teacher_vals, teacher_hint_idxs = tf.nn.top_k(
        hint_pool_sims * teacher_hints, k=1)
    neg_teacher_vals, _ = tf.nn.top_k(
        hint_pool_sims * (1 - teacher_hints), k=1)

    # bring back idxs to full memory
    teacher_idxs = tf.gather(
        tf.reshape(hint_pool_idxs, [-1]),
        teacher_hint_idxs[:, 0] + choose_k * tf.range(batch_size))

    # zero-out teacher_vals if there are no hints
    teacher_vals *= (
        1 - tf.to_float(tf.equal(0.0, tf.reduce_sum(teacher_hints, 1))))

    # prepare returned values
    nearest_neighbor = tf.to_int32(
        tf.argmax(hint_pool_sims[:, :choose_k - 1], 1))

    no_teacher_idxs = tf.gather(
        tf.reshape(hint_pool_idxs, [-1]),
        nearest_neighbor + choose_k * tf.range(batch_size))

    # we'll determine whether to do an update to memory based on whether
    # memory was queried correctly
    sliced_hints = tf.slice(teacher_hints, [0, 0], [-1, self.correct_in_top])
    incorrect_memory_lookup = tf.equal(0.0, tf.reduce_sum(sliced_hints, 1))

    # loss based on triplet loss
    teacher_loss = (tf.nn.relu(neg_teacher_vals - teacher_vals + self.alpha)
                    - self.alpha)

    with tf.device(self.var_cache_device):
      result = tf.gather(self.mem_vals, tf.reshape(no_teacher_idxs, [-1]))

    # prepare memory updates
    update_keys = normalized_query
    update_vals = intended_output

    fetched_idxs = teacher_idxs  # correctly fetched from memory
    with tf.device(self.var_cache_device):
      fetched_keys = tf.gather(self.mem_keys, fetched_idxs,
                               name='fetched_keys')
      fetched_vals = tf.gather(self.mem_vals, fetched_idxs,
                               name='fetched_vals')

    # do memory updates here
    fetched_keys_upd = update_keys + fetched_keys  # Momentum-like update
    fetched_keys_upd = tf.nn.l2_normalize(fetched_keys_upd, dim=1)
    # Randomize age a bit, e.g., to select different ones in parallel workers.
    mem_age_with_noise = self.mem_age + tf.random_uniform(
        [self.memory_size], - self.age_noise, self.age_noise)

    _, oldest_idxs = tf.nn.top_k(mem_age_with_noise, k=batch_size,
                                 sorted=False)

    with tf.control_dependencies([result]):
      upd_idxs = tf.where(incorrect_memory_lookup,
                          oldest_idxs,
                          fetched_idxs)
      # upd_idxs = tf.Print(upd_idxs, [upd_idxs], "UPD IDX", summarize=8)
      upd_keys = tf.where(incorrect_memory_lookup,
                          update_keys,
                          fetched_keys_upd)
      upd_vals = tf.where(incorrect_memory_lookup,
                          update_vals,
                          fetched_vals)

    def make_update_op():
      return self.make_update_op(upd_idxs, upd_keys, upd_vals,
                                 batch_size, use_recent_idx, intended_output)

    update_op = tf.cond(self.update_memory, make_update_op, tf.no_op)

    with tf.control_dependencies([update_op]):
      result = tf.identity(result)
      mask = tf.identity(mask)
      teacher_loss = tf.identity(teacher_loss)

    return result, mask, tf.reduce_mean(teacher_loss)


class LSHMemory(Memory):
  """Memory employing locality sensitive hashing.

  Note: Not fully tested.
  """

  def __init__(self, key_dim, memory_size, vocab_size,
               choose_k=256, alpha=0.1, correct_in_top=1, age_noise=8.0,
               var_cache_device='', nn_device='',
               num_hashes=None, num_libraries=None):
    super(LSHMemory, self).__init__(
        key_dim, memory_size, vocab_size,
        choose_k=choose_k, alpha=alpha, correct_in_top=1,
        age_noise=age_noise, var_cache_device=var_cache_device,
        nn_device=nn_device)

    self.num_libraries = num_libraries or int(self.choose_k ** 0.5)
    self.num_per_hash_slot = max(1, self.choose_k // self.num_libraries)
    self.num_hashes = (num_hashes or
                       int(np.log2(self.memory_size / self.num_per_hash_slot)))
    self.num_hashes = min(max(self.num_hashes, 1), 20)
    self.num_hash_slots = 2 ** self.num_hashes

    # hashing vectors
    self.hash_vecs = [
        tf.get_variable(
            'hash_vecs%d' % i, [self.num_hashes, self.key_dim],
            dtype=tf.float32, trainable=False,
            initializer=tf.truncated_normal_initializer(0, 1))
        for i in xrange(self.num_libraries)]

    # map representing which hash slots map to which mem keys
    self.hash_slots = [
        tf.get_variable(
            'hash_slots%d' % i, [self.num_hash_slots, self.num_per_hash_slot],
            dtype=tf.int32, trainable=False,
            initializer=tf.random_uniform_initializer(maxval=self.memory_size,
                                                      dtype=tf.int32))
        for i in xrange(self.num_libraries)]

  def get(self):  # not implemented
    return self.mem_keys, self.mem_vals, self.mem_age, self.recent_idx

  def set(self, k, v, a, r=None):  # not implemented
    return tf.group(
        self.mem_keys.assign(k),
        self.mem_vals.assign(v),
        self.mem_age.assign(a),
        (self.recent_idx.assign(r) if r is not None else tf.group()))

  def clear(self):
    return tf.variables_initializer([self.mem_keys, self.mem_vals,
                                     self.mem_age, self.recent_idx]
                                    + self.hash_slots)

  def get_hash_slots(self, query):
    """Gets hashed-to buckets for batch of queries.

    Args:
      query: 2-d Tensor of query vectors.

    Returns:
      A list of hashed-to buckets for each hash function.
    """

    binary_hash = [
        tf.less(tf.matmul(query, self.hash_vecs[i], transpose_b=True), 0)
        for i in xrange(self.num_libraries)]
    hash_slot_idxs = [
        tf.reduce_sum(
            tf.to_int32(binary_hash[i]) *
            tf.constant([[2 ** i for i in xrange(self.num_hashes)]],
                        dtype=tf.int32), 1)
        for i in xrange(self.num_libraries)]
    return hash_slot_idxs

  def get_hint_pool_idxs(self, normalized_query):
    """Get small set of idxs to compute nearest neighbor queries on.

    This is an expensive look-up on the whole memory that is used to
    avoid more expensive operations later on.

    Args:
      normalized_query: A Tensor of shape [None, key_dim].

    Returns:
      A Tensor of shape [None, choose_k] of indices in memory
      that are closest to the queries.
    """
    # get hash of query vecs
    hash_slot_idxs = self.get_hash_slots(normalized_query)

    # grab mem idxs in the hash slots
    hint_pool_idxs = [
        tf.maximum(tf.minimum(tf.gather(self.hash_slots[i], idxs),
                              self.memory_size - 1), 0)
        for i, idxs in enumerate(hash_slot_idxs)]

    return tf.concat(hint_pool_idxs, 1)

  def make_update_op(self, upd_idxs, upd_keys, upd_vals,
                     batch_size, use_recent_idx, intended_output):
    """Function that creates all the update ops."""
    base_update_op = super(LSHMemory, self).make_update_op(
        upd_idxs, upd_keys, upd_vals,
        batch_size, use_recent_idx, intended_output)

    # compute hash slots to be updated
    hash_slot_idxs = self.get_hash_slots(upd_keys)

    # make updates
    update_ops = []
    with tf.control_dependencies([base_update_op]):
      for i, slot_idxs in enumerate(hash_slot_idxs):
        # for each slot, choose which entry to replace
        entry_idx = tf.random_uniform([batch_size],
                                      maxval=self.num_per_hash_slot,
                                      dtype=tf.int32)
        entry_mul = 1 - tf.one_hot(entry_idx, self.num_per_hash_slot,
                                   dtype=tf.int32)
        entry_add = (tf.expand_dims(upd_idxs, 1) *
                     tf.one_hot(entry_idx, self.num_per_hash_slot,
                                dtype=tf.int32))

        mul_op = tf.scatter_mul(self.hash_slots[i], slot_idxs, entry_mul)
        with tf.control_dependencies([mul_op]):
          add_op = tf.scatter_add(self.hash_slots[i], slot_idxs, entry_add)
          update_ops.append(add_op)

    return tf.group(*update_ops)
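
As a quick numeric illustration of the softmax-temperature comment inside
Memory.query above (a hedged aside, not part of this commit): with the default
choose_k=256 and alpha=0.1, the temperature comes out to log(0.2 * 256) / 0.1 ≈ 39.4,
so similarity gaps on the order of alpha dominate the resulting mask.
```
# Hedged check of the softmax_temp formula used in Memory.query,
# with the defaults choose_k=256 and alpha=0.1.
import numpy as np

choose_k, alpha = 256, 0.1
softmax_temp = max(1.0, np.log(0.2 * choose_k) / alpha)
print(softmax_temp)  # ~39.4
```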
learning_to_remember_rare_events/model.py (new file, 0 → 100644)
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
"""Model using memory component.
The model embeds images using a standard CNN architecture.
These embeddings are used as keys to the memory component,
which returns nearest neighbors.
"""
import tensorflow as tf

import memory

FLAGS = tf.flags.FLAGS


class BasicClassifier(object):

  def __init__(self, output_dim):
    self.output_dim = output_dim

  def core_builder(self, memory_val, x, y):
    del x, y
    y_pred = memory_val
    loss = 0.0

    return loss, y_pred


class LeNet(object):
  """Standard CNN architecture."""

  def __init__(self, image_size, num_channels, hidden_dim):
    self.image_size = image_size
    self.num_channels = num_channels
    self.hidden_dim = hidden_dim
    self.matrix_init = tf.truncated_normal_initializer(stddev=0.1)
    self.vector_init = tf.constant_initializer(0.0)

  def core_builder(self, x):
    """Embeds x using standard CNN architecture.

    Args:
      x: Batch of images as a 2-d Tensor [batch_size, -1].

    Returns:
      A 2-d Tensor [batch_size, hidden_dim] of embedded images.
    """

    ch1 = 32 * 2  # number of channels in 1st layer
    ch2 = 64 * 2  # number of channels in 2nd layer
    conv1_weights = tf.get_variable('conv1_w',
                                    [3, 3, self.num_channels, ch1],
                                    initializer=self.matrix_init)
    conv1_biases = tf.get_variable('conv1_b', [ch1],
                                   initializer=self.vector_init)
    conv1a_weights = tf.get_variable('conv1a_w',
                                     [3, 3, ch1, ch1],
                                     initializer=self.matrix_init)
    conv1a_biases = tf.get_variable('conv1a_b', [ch1],
                                    initializer=self.vector_init)

    conv2_weights = tf.get_variable('conv2_w', [3, 3, ch1, ch2],
                                    initializer=self.matrix_init)
    conv2_biases = tf.get_variable('conv2_b', [ch2],
                                   initializer=self.vector_init)
    conv2a_weights = tf.get_variable('conv2a_w', [3, 3, ch2, ch2],
                                     initializer=self.matrix_init)
    conv2a_biases = tf.get_variable('conv2a_b', [ch2],
                                    initializer=self.vector_init)

    # fully connected
    fc1_weights = tf.get_variable(
        'fc1_w', [self.image_size // 4 * self.image_size // 4 * ch2,
                  self.hidden_dim], initializer=self.matrix_init)
    fc1_biases = tf.get_variable('fc1_b', [self.hidden_dim],
                                 initializer=self.vector_init)

    # define model
    x = tf.reshape(x,
                   [-1, self.image_size, self.image_size, self.num_channels])
    batch_size = tf.shape(x)[0]

    conv1 = tf.nn.conv2d(x, conv1_weights,
                         strides=[1, 1, 1, 1], padding='SAME')
    relu1 = tf.nn.relu(tf.nn.bias_add(conv1, conv1_biases))
    conv1 = tf.nn.conv2d(relu1, conv1a_weights,
                         strides=[1, 1, 1, 1], padding='SAME')
    relu1 = tf.nn.relu(tf.nn.bias_add(conv1, conv1a_biases))

    pool1 = tf.nn.max_pool(relu1, ksize=[1, 2, 2, 1],
                           strides=[1, 2, 2, 1], padding='SAME')

    conv2 = tf.nn.conv2d(pool1, conv2_weights,
                         strides=[1, 1, 1, 1], padding='SAME')
    relu2 = tf.nn.relu(tf.nn.bias_add(conv2, conv2_biases))
    conv2 = tf.nn.conv2d(relu2, conv2a_weights,
                         strides=[1, 1, 1, 1], padding='SAME')
    relu2 = tf.nn.relu(tf.nn.bias_add(conv2, conv2a_biases))

    pool2 = tf.nn.max_pool(relu2, ksize=[1, 2, 2, 1],
                           strides=[1, 2, 2, 1], padding='SAME')

    reshape = tf.reshape(pool2, [batch_size, -1])
    hidden = tf.matmul(reshape, fc1_weights) + fc1_biases

    return hidden


class Model(object):
  """Model for coordinating between CNN embedder and Memory module."""

  def __init__(self, input_dim, output_dim, rep_dim, memory_size, vocab_size,
               learning_rate=0.0001, use_lsh=False):
    self.input_dim = input_dim
    self.output_dim = output_dim
    self.rep_dim = rep_dim
    self.memory_size = memory_size
    self.vocab_size = vocab_size
    self.learning_rate = learning_rate
    self.use_lsh = use_lsh

    self.embedder = self.get_embedder()
    self.memory = self.get_memory()
    self.classifier = self.get_classifier()

    self.global_step = tf.contrib.framework.get_or_create_global_step()

  def get_embedder(self):
    return LeNet(int(self.input_dim ** 0.5), 1, self.rep_dim)

  def get_memory(self):
    cls = memory.LSHMemory if self.use_lsh else memory.Memory
    return cls(self.rep_dim, self.memory_size, self.vocab_size)

  def get_classifier(self):
    return BasicClassifier(self.output_dim)

  def core_builder(self, x, y, keep_prob, use_recent_idx=True):
    embeddings = self.embedder.core_builder(x)
    if keep_prob < 1.0:
      embeddings = tf.nn.dropout(embeddings, keep_prob)
    memory_val, _, teacher_loss = self.memory.query(
        embeddings, y, use_recent_idx=use_recent_idx)
    loss, y_pred = self.classifier.core_builder(memory_val, x, y)

    return loss + teacher_loss, y_pred

  def train(self, x, y):
    loss, _ = self.core_builder(x, y, keep_prob=0.3)
    gradient_ops = self.training_ops(loss)
    return loss, gradient_ops

  def eval(self, x, y):
    _, y_preds = self.core_builder(x, y, keep_prob=1.0,
                                   use_recent_idx=False)
    return y_preds

  def get_xy_placeholders(self):
    return (tf.placeholder(tf.float32, [None, self.input_dim]),
            tf.placeholder(tf.int32, [None]))

  def setup(self):
    """Sets up all components of the computation graph."""

    self.x, self.y = self.get_xy_placeholders()

    with tf.variable_scope('core', reuse=None):
      self.loss, self.gradient_ops = self.train(self.x, self.y)
    with tf.variable_scope('core', reuse=True):
      self.y_preds = self.eval(self.x, self.y)

    # setup memory "reset" ops
    (self.mem_keys, self.mem_vals,
     self.mem_age, self.recent_idx) = self.memory.get()
    self.mem_keys_reset = tf.placeholder(
        self.mem_keys.dtype, tf.identity(self.mem_keys).shape)
    self.mem_vals_reset = tf.placeholder(
        self.mem_vals.dtype, tf.identity(self.mem_vals).shape)
    self.mem_age_reset = tf.placeholder(
        self.mem_age.dtype, tf.identity(self.mem_age).shape)
    self.recent_idx_reset = tf.placeholder(
        self.recent_idx.dtype, tf.identity(self.recent_idx).shape)
    self.mem_reset_op = self.memory.set(self.mem_keys_reset,
                                        self.mem_vals_reset,
                                        self.mem_age_reset,
                                        None)

  def training_ops(self, loss):
    opt = self.get_optimizer()
    params = tf.trainable_variables()
    gradients = tf.gradients(loss, params)
    clipped_gradients, _ = tf.clip_by_global_norm(gradients, 5.0)
    return opt.apply_gradients(zip(clipped_gradients, params),
                               global_step=self.global_step)

  def get_optimizer(self):
    return tf.train.AdamOptimizer(learning_rate=self.learning_rate,
                                  epsilon=1e-4)

  def one_step(self, sess, x, y):
    outputs = [self.loss, self.gradient_ops]
    return sess.run(outputs, feed_dict={self.x: x, self.y: y})

  def episode_step(self, sess, x, y, clear_memory=False):
    """Performs training steps on episodic input.

    Args:
      sess: A Tensorflow Session.
      x: A list of batches of images defining the episode.
      y: A list of batches of labels corresponding to x.
      clear_memory: Whether to clear the memory before the episode.

    Returns:
      List of losses the same length as the episode.
    """

    outputs = [self.loss, self.gradient_ops]

    if clear_memory:
      self.clear_memory(sess)

    losses = []
    for xx, yy in zip(x, y):
      out = sess.run(outputs, feed_dict={self.x: xx, self.y: yy})
      loss = out[0]
      losses.append(loss)

    return losses

  def predict(self, sess, x, y=None):
    """Predict the labels on a single batch of examples.

    Args:
      sess: A Tensorflow Session.
      x: A batch of images.
      y: The labels for the images in x.
        This allows for updating the memory.

    Returns:
      Predicted y.
    """

    cur_memory = sess.run([self.mem_keys, self.mem_vals, self.mem_age])

    outputs = [self.y_preds]
    if y is None:
      ret = sess.run(outputs, feed_dict={self.x: x})
    else:
      ret = sess.run(outputs, feed_dict={self.x: x, self.y: y})

    sess.run([self.mem_reset_op],
             feed_dict={self.mem_keys_reset: cur_memory[0],
                        self.mem_vals_reset: cur_memory[1],
                        self.mem_age_reset: cur_memory[2]})

    return ret

  def episode_predict(self, sess, x, y, clear_memory=False):
    """Predict the labels on an episode of examples.

    Args:
      sess: A Tensorflow Session.
      x: A list of batches of images.
      y: A list of labels for the images in x.
        This allows for updating the memory.
      clear_memory: Whether to clear the memory before the episode.

    Returns:
      List of predicted y.
    """

    cur_memory = sess.run([self.mem_keys, self.mem_vals, self.mem_age])

    if clear_memory:
      self.clear_memory(sess)

    outputs = [self.y_preds]
    y_preds = []
    for xx, yy in zip(x, y):
      out = sess.run(outputs, feed_dict={self.x: xx, self.y: yy})
      y_pred = out[0]
      y_preds.append(y_pred)

    sess.run([self.mem_reset_op],
             feed_dict={self.mem_keys_reset: cur_memory[0],
                        self.mem_vals_reset: cur_memory[1],
                        self.mem_age_reset: cur_memory[2]})

    return y_preds

  def clear_memory(self, sess):
    sess.run([self.memory.clear()])
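
For orientation, a hedged sketch (not part of this commit) of how the Model class
above is driven; it mirrors what train.py does below, with sizes taken from the
README command (rep_dim=128, memory_size=8192, episode_width=5, batch_size=16,
28x28 inputs).
```
# Hedged sketch of driving Model directly; train.py below does this properly.
import tensorflow as tf
import model

input_dim, output_dim, rep_dim = 28 * 28, 5, 128
memory_size, vocab_size = 8192, 5 * 16   # vocab = episode_width * batch_size

m = model.Model(input_dim, output_dim, rep_dim, memory_size, vocab_size)
m.setup()

sess = tf.Session()
sess.run(tf.initialize_all_variables())

# x, y would be lists (one entry per episode step) of [batch_size, input_dim]
# float32 images and [batch_size] int32 labels, e.g. from
# Trainer.sample_episode_batch in train.py:
# losses = m.episode_step(sess, x, y, clear_memory=True)
```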
learning_to_remember_rare_events/train.py (new file, 0 → 100644)
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
r"""Script for training model.

Simple command to get up and running:
  python train.py --memory_size=8192 \
    --batch_size=16 --validation_length=50 \
    --episode_width=5 --episode_length=30
"""

import logging
import os
import random

import numpy as np
import tensorflow as tf

import data_utils
import model

FLAGS = tf.flags.FLAGS

tf.flags.DEFINE_integer('rep_dim', 128,
                        'dimension of keys to use in memory')
tf.flags.DEFINE_integer('episode_length', 100, 'length of episode')
tf.flags.DEFINE_integer('episode_width', 5,
                        'number of distinct labels in a single episode')
tf.flags.DEFINE_integer('memory_size', None, 'number of slots in memory. '
                        'Leave as None to default to episode length')
tf.flags.DEFINE_integer('batch_size', 16, 'batch size')
tf.flags.DEFINE_integer('num_episodes', 100000, 'number of training episodes')
tf.flags.DEFINE_integer('validation_frequency', 20,
                        'every so many training episodes, '
                        'assess validation accuracy')
tf.flags.DEFINE_integer('validation_length', 10,
                        'number of episodes to use to compute '
                        'validation accuracy')
tf.flags.DEFINE_integer('seed', 888, 'random seed for training sampling')
tf.flags.DEFINE_string('save_dir', '', 'directory to save model to')
tf.flags.DEFINE_bool('use_lsh', False,
                     'use locality-sensitive hashing '
                     '(NOTE: not fully tested)')


class Trainer(object):
  """Class that takes care of training, validating, and checkpointing model."""

  def __init__(self, train_data, valid_data, input_dim, output_dim=None):
    self.train_data = train_data
    self.valid_data = valid_data
    self.input_dim = input_dim

    self.rep_dim = FLAGS.rep_dim
    self.episode_length = FLAGS.episode_length
    self.episode_width = FLAGS.episode_width
    self.batch_size = FLAGS.batch_size
    self.memory_size = (self.episode_length * self.batch_size
                        if FLAGS.memory_size is None else FLAGS.memory_size)
    self.use_lsh = FLAGS.use_lsh

    self.output_dim = (output_dim if output_dim is not None
                       else self.episode_width)

  def get_model(self):
    # vocab size is the number of distinct values that
    # could go into the memory key-value storage
    vocab_size = self.episode_width * self.batch_size
    return model.Model(
        self.input_dim, self.output_dim, self.rep_dim, self.memory_size,
        vocab_size, use_lsh=self.use_lsh)

  def sample_episode_batch(self, data,
                           episode_length, episode_width, batch_size):
    """Generates a random batch for training or validation.

    Structures each element of the batch as an 'episode'.
    Each episode contains episode_length examples and
    episode_width distinct labels.

    Args:
      data: A dictionary mapping label to list of examples.
      episode_length: Number of examples in each episode.
      episode_width: Distinct number of labels in each episode.
      batch_size: Batch size (number of episodes).

    Returns:
      A tuple (x, y) where x is a list of batches of examples
      with size episode_length and y is a list of batches of labels.
    """

    episodes_x = [[] for _ in xrange(episode_length)]
    episodes_y = [[] for _ in xrange(episode_length)]
    assert len(data) >= episode_width
    keys = data.keys()
    for b in xrange(batch_size):
      episode_labels = random.sample(keys, episode_width)
      remainder = episode_length % episode_width
      remainders = [0] * (episode_width - remainder) + [1] * remainder
      episode_x = [
          random.sample(data[lab],
                        r + (episode_length - remainder) / episode_width)
          for lab, r in zip(episode_labels, remainders)]
      episode = sum([[(x, i, ii) for ii, x in enumerate(xx)]
                     for i, xx in enumerate(episode_x)], [])
      random.shuffle(episode)
      # Arrange episode so that each distinct label is seen before moving to
      # 2nd showing
      episode.sort(key=lambda elem: elem[2])
      assert len(episode) == episode_length
      for i in xrange(episode_length):
        episodes_x[i].append(episode[i][0])
        episodes_y[i].append(episode[i][1] + b * episode_width)

    return ([np.array(xx).astype('float32') for xx in episodes_x],
            [np.array(yy).astype('int32') for yy in episodes_y])

  def compute_correct(self, ys, y_preds):
    return np.mean(np.equal(y_preds, np.array(ys)))

  def individual_compute_correct(self, y, y_pred):
    return y_pred == y

  def run(self):
    """Performs training.

    Trains a model using episodic training.
    Every so often, runs some evaluations on validation data.
    """

    train_data, valid_data = self.train_data, self.valid_data
    input_dim, output_dim = self.input_dim, self.output_dim
    rep_dim, episode_length = self.rep_dim, self.episode_length
    episode_width, memory_size = self.episode_width, self.memory_size
    batch_size = self.batch_size

    train_size = len(train_data)
    valid_size = len(valid_data)
    logging.info('train_size (number of labels) %d', train_size)
    logging.info('valid_size (number of labels) %d', valid_size)
    logging.info('input_dim %d', input_dim)
    logging.info('output_dim %d', output_dim)
    logging.info('rep_dim %d', rep_dim)
    logging.info('episode_length %d', episode_length)
    logging.info('episode_width %d', episode_width)
    logging.info('memory_size %d', memory_size)
    logging.info('batch_size %d', batch_size)

    assert all(len(v) >= float(episode_length) / episode_width
               for v in train_data.itervalues())
    assert all(len(v) >= float(episode_length) / episode_width
               for v in valid_data.itervalues())

    output_dim = episode_width
    self.model = self.get_model()
    self.model.setup()

    sess = tf.Session()
    sess.run(tf.initialize_all_variables())

    saver = tf.train.Saver(max_to_keep=10)
    ckpt = None
    if FLAGS.save_dir:
      ckpt = tf.train.get_checkpoint_state(FLAGS.save_dir)
    if ckpt and ckpt.model_checkpoint_path:
      logging.info('restoring from %s', ckpt.model_checkpoint_path)
      saver.restore(sess, ckpt.model_checkpoint_path)

    logging.info('starting now')
    losses = []
    random.seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)
    for i in xrange(FLAGS.num_episodes):
      x, y = self.sample_episode_batch(
          train_data, episode_length, episode_width, batch_size)
      outputs = self.model.episode_step(sess, x, y, clear_memory=True)
      loss = outputs
      losses.append(loss)

      if i % FLAGS.validation_frequency == 0:
        logging.info('episode batch %d, avg train loss %f',
                     i, np.mean(losses))
        losses = []

        # validation
        correct = []
        correct_by_shot = dict((k, [])
                               for k in xrange(self.episode_width + 1))
        for _ in xrange(FLAGS.validation_length):
          x, y = self.sample_episode_batch(
              valid_data, episode_length, episode_width, 1)
          outputs = self.model.episode_predict(
              sess, x, y, clear_memory=True)
          y_preds = outputs
          correct.append(self.compute_correct(np.array(y), y_preds))

          # compute per-shot accuracies
          seen_counts = [[0] * episode_width for _ in xrange(batch_size)]
          # loop over episode steps
          for yy, yy_preds in zip(y, y_preds):
            # loop over batch examples
            for k, (yyy, yyy_preds) in enumerate(zip(yy, yy_preds)):
              yyy, yyy_preds = int(yyy), int(yyy_preds)
              count = seen_counts[k][yyy % self.episode_width]
              if count in correct_by_shot:
                correct_by_shot[count].append(
                    self.individual_compute_correct(yyy, yyy_preds))
              seen_counts[k][yyy % self.episode_width] = count + 1

        logging.info('validation overall accuracy %f', np.mean(correct))
        logging.info('%d-shot: %.3f, ' * (self.episode_width + 1),
                     *sum([[k, np.mean(correct_by_shot[k])]
                           for k in xrange(self.episode_width + 1)], []))

        if saver and FLAGS.save_dir:
          saved_file = saver.save(sess,
                                  os.path.join(FLAGS.save_dir, 'model.ckpt'),
                                  global_step=self.model.global_step)
          logging.info('saved model to %s', saved_file)


def main(unused_argv):
  train_data, valid_data = data_utils.get_data()
  trainer = Trainer(train_data, valid_data, data_utils.IMAGE_NEW_SIZE ** 2)
  trainer.run()


if __name__ == '__main__':
  logging.basicConfig(level=logging.INFO)
  tf.app.run()