Commit 90d6e3b9 authored by Lukasz Kaiser

Cleaned-up version of the Neural GPU code (runs outside of google3).

parent d84df16b
README:
@@ -2,4 +2,10 @@
Code for the Neural GPU model as described
in [[http://arxiv.org/abs/1511.08228]].
+Requirements:
+* TensorFlow (see tensorflow.org for how to install)
+* Matplotlib for Python (sudo apt-get install python-matplotlib)
+Run: python neural_gpu_trainer.py --task=rev
Maintained by Lukasz Kaiser (lukaszkaiser)
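For reference, training options are ordinary TensorFlow flags defined in neural_gpu_trainer.py (a subset of them appears later in this diff), so a run beyond the default only needs a few overrides. A minimal sketch, with illustrative values that are not part of this commit:

python neural_gpu_trainer.py --task=rev --lr=0.1 --nmaps=24 --train_dir=/tmp/neural_gpu

Checkpoints are written to a neural_gpu* subdirectory of --train_dir (see initialize() further down in this diff).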

data_utils.py:
-# Copyright 2015 Google Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-#==============================================================================
"""Convolutional Gated Recurrent Networks for Algorithm Learning."""
import math
@@ -21,12 +5,10 @@ import random
import sys
import time
-import google3
import numpy as np
import tensorflow as tf
-from google3.third_party.tensorflow.python.platform import gfile
+from tensorflow.python.platform import gfile
FLAGS = tf.app.flags.FLAGS
@@ -162,6 +144,21 @@ def init_data(task, length, nbr_cases, nclass):
test_set[task][l].append([inp, target])
+def to_symbol(i):
+  """Convert ids to text."""
+  if i == 0: return ""
+  if i == 11: return "+"
+  if i == 12: return "*"
+  return str(i-1)
+
+def to_id(s):
+  """Convert text to ids."""
+  if s == "+": return 11
+  if s == "*": return 12
+  return int(s) + 1
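As a quick illustration of the mapping these two helpers implement (this snippet is not part of the commit; it only exercises the functions added above and assumes data_utils.py is importable): id 0 is reserved for padding, "+" and "*" map to 11 and 12, and digits are shifted up by one.

import data_utils

tokens = "2 * 3".split()                        # ["2", "*", "3"]
ids = [data_utils.to_id(t) for t in tokens]     # [3, 12, 4]
back = [data_utils.to_symbol(i) for i in ids]   # ["2", "*", "3"]
assert back == tokens
assert data_utils.to_symbol(0) == ""            # 0 is the padding id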
def get_batch(max_length, batch_size, do_train, task, offset=None, preset=None):
  """Get a batch of data, training or testing."""
  inputs = []
...

neural_gpu.py:
-# Copyright 2015 Google Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-#==============================================================================
"""The Neural GPU Model."""
import time
-import google3
import tensorflow as tf
-from google3.experimental.users.lukaszkaiser.neural_gpu import data_utils
+import data_utils
def conv_linear(args, kw, kh, nin, nout, do_bias, bias_start, prefix):
...

neural_gpu_trainer.py:
-# Copyright 2015 Google Inc. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-#==============================================================================
"""Neural GPU for Learning Algorithms."""
import math
@@ -22,16 +6,15 @@ import random
import sys
import time
-import google3
import matplotlib.animation as anim
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
-from google3.third_party.tensorflow.python.platform import gfile
+from tensorflow.python.platform import gfile
-import google3.experimental.users.lukaszkaiser.neural_gpu.data_utils as data
-import google3.experimental.users.lukaszkaiser.neural_gpu.neural_gpu as ngpu
+import data_utils as data
+import neural_gpu
tf.app.flags.DEFINE_float("lr", 0.1, "Learning rate.") tf.app.flags.DEFINE_float("lr", 0.1, "Learning rate.")
tf.app.flags.DEFINE_float("init_weight", 1.0, "Initial weights deviation.") tf.app.flags.DEFINE_float("init_weight", 1.0, "Initial weights deviation.")
...@@ -39,7 +22,7 @@ tf.app.flags.DEFINE_float("max_grad_norm", 0.05, "Clip gradients to this norm.") ...@@ -39,7 +22,7 @@ tf.app.flags.DEFINE_float("max_grad_norm", 0.05, "Clip gradients to this norm.")
tf.app.flags.DEFINE_float("cutoff", 1.2, "Cutoff at the gates.") tf.app.flags.DEFINE_float("cutoff", 1.2, "Cutoff at the gates.")
tf.app.flags.DEFINE_float("pull", 0.0005, "Starting pull of the relaxations.") tf.app.flags.DEFINE_float("pull", 0.0005, "Starting pull of the relaxations.")
tf.app.flags.DEFINE_float("pull_incr", 1.2, "Increase pull by that much.") tf.app.flags.DEFINE_float("pull_incr", 1.2, "Increase pull by that much.")
tf.app.flags.DEFINE_float("dropout", 0.2, "Dropout that much.") tf.app.flags.DEFINE_float("dropout", 0.15, "Dropout that much.")
tf.app.flags.DEFINE_float("grad_noise_scale", 1.0, "Gradient noise scale.") tf.app.flags.DEFINE_float("grad_noise_scale", 1.0, "Gradient noise scale.")
tf.app.flags.DEFINE_integer("batch_size", 64, "Batch size.") tf.app.flags.DEFINE_integer("batch_size", 64, "Batch size.")
tf.app.flags.DEFINE_integer("low_batch_size", 16, "Low batch size.") tf.app.flags.DEFINE_integer("low_batch_size", 16, "Low batch size.")
...@@ -63,6 +46,7 @@ tf.app.flags.DEFINE_string("task", "rev", "Which task are we learning?") ...@@ -63,6 +46,7 @@ tf.app.flags.DEFINE_string("task", "rev", "Which task are we learning?")
tf.app.flags.DEFINE_string("train_dir", "/tmp/", "Directory to store models.") tf.app.flags.DEFINE_string("train_dir", "/tmp/", "Directory to store models.")
FLAGS = tf.app.flags.FLAGS FLAGS = tf.app.flags.FLAGS
EXTRA_EVAL = 12

def initialize(sess):
@@ -83,7 +67,7 @@ def initialize(sess):
min_length = 3
max_length = min(FLAGS.max_length, data.bins[-1])
assert max_length + 1 > min_length
-while len(data.bins) > 1 and data.bins[-2] > max_length + 12:
+while len(data.bins) > 1 and data.bins[-2] > max_length + EXTRA_EVAL:
data.bins = data.bins[:-1]
assert data.bins[0] > FLAGS.rx_step
nclass = min(FLAGS.niclass, FLAGS.noclass)
@@ -92,7 +76,7 @@ def initialize(sess):
# Initialize data for each task.
tasks = FLAGS.task.split("-")
for t in tasks:
-for l in xrange(max_length + 11):
+for l in xrange(max_length + EXTRA_EVAL - 1):
data.init_data(t, l, data_size, nclass)
data.init_data(t, data.bins[-2], data_size, nclass)
data.init_data(t, data.bins[-1], data_size, nclass)
@@ -101,14 +85,14 @@ def initialize(sess):
# Print out parameters.
curriculum = 0.12
-fin = ("cv %d kw %d h %d kh %d rxr %d bs %d ns %.2f t %s"
+msg1 = ("layers %d kw %d h %d kh %d relax %d batch %d noise %.2f task %s"
% (FLAGS.nconvs, FLAGS.kw, FLAGS.height, FLAGS.kh, FLAGS.rx_step,
FLAGS.batch_size, FLAGS.grad_noise_scale, FLAGS.task))
-fin = "data %d %s" % (FLAGS.train_data_size, fin)
+msg2 = "data %d %s" % (FLAGS.train_data_size, msg1)
-tag = ("df %.2f p %.3f lr %.2f iw %.2f cr %.2f nm %d d%.4f gn %.2f %s" %
+msg3 = ("cut %.2f pull %.3f lr %.2f iw %.2f cr %.2f nm %d d%.4f gn %.2f %s" %
(FLAGS.cutoff, FLAGS.pull_incr, FLAGS.lr, FLAGS.init_weight,
-curriculum, FLAGS.nmaps, FLAGS.dropout, FLAGS.max_grad_norm, fin))
+curriculum, FLAGS.nmaps, FLAGS.dropout, FLAGS.max_grad_norm, msg2))
-data.print_out(tag)
+data.print_out(msg3)
# Create checkpoint directory if it does not exist.
checkpoint_dir = os.path.join(FLAGS.train_dir, "neural_gpu%s"
@@ -120,7 +104,7 @@ def initialize(sess):
# Create model and initialize it.
tf.get_variable_scope().set_initializer(
tf.uniform_unit_scaling_initializer(factor=1.8 * FLAGS.init_weight))
-model = ngpu.NeuralGPU(
+model = neural_gpu.NeuralGPU(
FLAGS.nmaps, FLAGS.nmaps, FLAGS.niclass, FLAGS.noclass, FLAGS.dropout,
FLAGS.rx_step, FLAGS.max_grad_norm, FLAGS.cutoff, FLAGS.nconvs,
FLAGS.kw, FLAGS.kh, FLAGS.height, FLAGS.mode, FLAGS.lr,
@@ -145,131 +129,148 @@ def single_test(l, model, sess, task, nprint, batch_size, print_out=True,
"""Test model on test data of length l using the given session."""
inpt, target = data.get_batch(l, batch_size, False, task, offset)
_, res, _, steps = model.step(sess, inpt, target, False)
-errors, total, seq = data.accuracy(inpt, res, target, batch_size, nprint)
+errors, total, seq_err = data.accuracy(inpt, res, target, batch_size, nprint)
-seq = float(seq) / batch_size
+seq_err = float(seq_err) / batch_size
if total > 0:
errors = float(errors) / total
if print_out:
data.print_out(" %s len %d errors %.2f sequence-errors %.2f"
-% (task, l, 100*errors, 100*seq))
+% (task, l, 100*errors, 100*seq_err))
-return errors, seq, (steps, inpt, [np.argmax(o, axis=1) for o in res])
+return errors, seq_err, (steps, inpt, [np.argmax(o, axis=1) for o in res])

def multi_test(l, model, sess, task, nprint, batch_size, offset=None):
"""Run multiple tests at lower batch size to save memory."""
-errors = 0.0
-seq = 0.0
+errors, seq_err = 0.0, 0.0
to_print = nprint
low_batch = FLAGS.low_batch_size
low_batch = min(low_batch, batch_size)
for mstep in xrange(batch_size / low_batch):
cur_offset = None if offset is None else offset + mstep * low_batch
-err, sq, _ = single_test(l, model, sess, task, to_print, low_batch, False,
-cur_offset)
+err, sq_err, _ = single_test(l, model, sess, task, to_print, low_batch,
+False, cur_offset)
to_print = max(0, to_print - low_batch)
errors += err
-seq += sq
+seq_err += sq_err
if FLAGS.mode > 0:
cur_errors = float(low_batch * errors) / ((mstep+1) * low_batch)
-cur_seq = float(low_batch * seq) / ((mstep+1) * low_batch)
+cur_seq_err = float(low_batch * seq_err) / ((mstep+1) * low_batch)
data.print_out(" %s multitest current errors %.2f sequence-errors %.2f"
-% (task, 100*cur_errors, 100*cur_seq))
+% (task, 100*cur_errors, 100*cur_seq_err))
errors = float(low_batch) * float(errors) / batch_size
-seq = float(low_batch) * float(seq) / batch_size
+seq_err = float(low_batch) * float(seq_err) / batch_size
data.print_out(" %s len %d errors %.2f sequence-errors %.2f"
-% (task, l, 100*errors, 100*seq))
+% (task, l, 100*errors, 100*seq_err))
-return errors, seq
+return errors, seq_err

def train():
-"""Main training function."""
+"""Train the model."""
batch_size = FLAGS.batch_size
tasks = FLAGS.task.split("-")
with tf.Session() as sess:
model, min_length, max_length, checkpoint_dir, curriculum = initialize(sess)
max_cur_length = min(min_length + 3, max_length)
prev_acc_perp = [1000000 for _ in xrange(3)]
-prev_sq = 1.0
+prev_seq_err = 1.0
+# Main training loop.
while True:
global_step, pull, max_cur_length, learning_rate = sess.run(
[model.global_step, model.pull, model.cur_length, model.lr])
-ep = global_step / FLAGS.steps_per_checkpoint
-acc_loss, acc_total, acc_errors, acc_seq = 0.0, 0, 0, 0
+acc_loss, acc_total, acc_errors, acc_seq_err = 0.0, 0, 0, 0
acc_grad_norm, step_count, step_time = 0.0, 0, 0.0
for _ in xrange(FLAGS.steps_per_checkpoint):
global_step += 1
task = random.choice(tasks)
-l1 = np.random.randint(max_cur_length - min_length + 1) + min_length
-l = l1
-if np.random.randint(10) > 3: # Prefer longer stuff 60% of time.
-l = np.random.randint(max_cur_length - min_length+1) + min_length
+# Select the length for curriculum learning.
+l = np.random.randint(max_cur_length - min_length + 1) + min_length
+# Prefer longer stuff 60% of time.
+if np.random.randint(100) < 60:
+l1 = np.random.randint(max_cur_length - min_length+1) + min_length
l = max(l, l1)
-if np.random.randint(4) < 1: # Mixed learning: once in a while big.
-l = np.random.randint(max_length - min_length + 1) + min_length
+# Mixed curriculum learning: in 25% of cases go to any larger length.
+if np.random.randint(100) < 25:
+l1 = np.random.randint(max_length - min_length + 1) + min_length
l = max(l, l1)
+# Run a step and time it.
start_time = time.time()
inp, target = data.get_batch(l, batch_size, True, task)
-stepp = math.pow(global_step, -0.55)
-noise_param = math.sqrt(stepp * 20 * prev_sq) * FLAGS.grad_noise_scale
+noise_param = math.sqrt(math.pow(global_step, -0.55) *
+(20 * prev_seq_err)) * FLAGS.grad_noise_scale
loss, res, gnorm, _ = model.step(sess, inp, target, True, noise_param)
step_time += time.time() - start_time
acc_grad_norm += float(gnorm)
+# Accumulate statistics only if we did not exceed curriculum length.
if l < max_cur_length + 1:
step_count += 1
acc_loss += loss
-errors, total, seq = data.accuracy(inp, res, target,
+errors, total, seq_err = data.accuracy(inp, res, target,
batch_size, 0)
acc_total += total
acc_errors += errors
-acc_seq += seq
+acc_seq_err += seq_err
+# Normalize and print out accumulated statistics.
acc_loss /= step_count
step_time /= FLAGS.steps_per_checkpoint
-acc_seq = float(acc_seq) / (step_count * batch_size)
-prev_sq = acc_seq
+acc_seq_err = float(acc_seq_err) / (step_count * batch_size)
+prev_seq_err = acc_seq_err
acc_errors = float(acc_errors) / acc_total if acc_total > 0 else 1.0
-msg1 = "ep %d st %.2f lr %.8f" % (ep, step_time, learning_rate)
+msg1 = "step %d step-time %.2f" % (global_step, step_time)
-msg2 = "pl %.3f cme %.3f" % (pull, curriculum)
+msg2 = "lr %.8f pull %.3f" % (learning_rate, pull)
-msg = ("%s %s gn %.8f"
+msg3 = ("%s %s grad-norm %.8f"
% (msg1, msg2, acc_grad_norm / FLAGS.steps_per_checkpoint))
-data.print_out("%s len %d ppx %.8f errs %.2f sq %.2f" %
-(msg, max_cur_length, data.safe_exp(acc_loss),
-100*acc_errors, 100*acc_seq))
+data.print_out("%s len %d ppx %.8f errors %.2f sequence-errors %.2f" %
+(msg3, max_cur_length, data.safe_exp(acc_loss),
+100*acc_errors, 100*acc_seq_err))
-if curriculum > acc_seq:
-prev_acc_perp.append(1000000)
+# If errors are below the curriculum threshold, move curriculum forward.
+if curriculum > acc_seq_err:
+# Increase current length (until the next with training data).
do_incr = True
while do_incr and max_cur_length < max_length:
sess.run(model.cur_length_incr_op)
for t in tasks:
if data.train_set[t]: do_incr = False
+# Forget last perplexities if we're not yet at the end.
+if max_cur_length < max_length:
+prev_acc_perp.append(1000000)
+# Either increase pull or, if it's large, average parameters.
if pull < 1:
sess.run(model.pull_incr_op)
else:
data.print_out(" Averaging parameters.")
sess.run([model.avg_op, model.lr_decay_op])
-else:
-acc_perp = data.safe_exp(acc_loss)
-if acc_perp > max(prev_acc_perp[-3:]):
-sess.run(model.lr_decay_op)
-prev_acc_perp.append(acc_perp)
+# Lower learning rate if we're worse than the last 3 checkpoints.
+acc_perp = data.safe_exp(acc_loss)
+if acc_perp > max(prev_acc_perp[-3:]):
+sess.run(model.lr_decay_op)
+prev_acc_perp.append(acc_perp)
+# Save checkpoint.
checkpoint_path = os.path.join(checkpoint_dir, "neural_gpu.ckpt")
model.saver.save(sess, checkpoint_path,
global_step=model.global_step)
# Run evaluation.
-should_exit = True
bound = data.bins[-1] + 1
for t in tasks:
l = min_length
-while l < max_length + 12 and l < bound:
+while l < max_length + EXTRA_EVAL and l < bound:
-_, sq, _ = single_test(l, model, sess, t, FLAGS.nprint, batch_size)
+_, seq_err, _ = single_test(l, model, sess, t,
+FLAGS.nprint, batch_size)
l += 1
while l < bound + 1 and not data.test_set[t][l]:
l += 1
-if sq < 0.5:
+if seq_err < 0.5: # Run larger test if we're good enough.
-_, sq = multi_test(data.forward_max, model, sess, t, FLAGS.nprint,
-batch_size * 4)
+_, seq_err = multi_test(data.forward_max, model, sess, t,
+FLAGS.nprint, batch_size * 4)
-if sq > 0.001: should_exit = False
-if should_exit:
+if seq_err < 0.01: # Super-large test on 1-task large-forward models.
if data.forward_max > 4000 and len(tasks) == 1:
multi_test(data.forward_max, model, sess, tasks[0], FLAGS.nprint,
batch_size * 16, 0)
@@ -277,14 +278,17 @@ def train():

def animate(l, test_data, anim_size):
"""Create animation for the given data (hacky matplotlib use)."""
-xf = 12
+xf = 12 # Extra frames to slow down at start and end.
-fps = 2
+fps = 2 # Frames per step.
+# Make the figure.
fig = plt.figure(figsize=(16, 9), facecolor="white")
ax = fig.add_axes([0, 0, 1, 1], frameon=False, zorder=2)
ax.set_xticks([i * 24-0.5 for i in xrange(4)])
ax.set_xticklabels([])
ax.set_yticks([i - 0.5 for i in xrange(l+1)])
ax.grid(which="major", axis="both", linestyle="-", color="black")
+# We need text fields.
text_fields = []
text_size = 24*32/l
for y in xrange(l):
@@ -296,11 +300,8 @@ def animate(l, test_data, anim_size):
vmax=1.0, cmap="gray", aspect="auto", origin="upper",
interpolation="none", animated=True)
im.set_zorder(1)
-def to_symbol(i):
-if i == 0: return ""
-if i == 11: return "+"
-if i == 12: return "*"
-return str(i-1)
+# Main animation step.
def animation_update(frame_no, test_data, xf, im, text_fields):
"""Update an animation frame."""
steps, inpt, out_raw = test_data
@@ -319,15 +320,17 @@ def animate(l, test_data, anim_size):
if index - 2*xf < length:
t.set_text("")
else:
-t.set_text(to_symbol(out[i]))
+t.set_text(data.to_symbol(out[i]))
else:
for i, t in enumerate(text_fields):
-t.set_text(to_symbol(inpt[i][batch]) if index < xf else "")
+t.set_text(data.to_symbol(inpt[i][batch]) if index < xf else "")
if index < xf:
im.set_array(np.zeros_like(steps[0][0]))
else:
im.set_array(steps[0][batch])
return im,
+# Create the animation and save to mp4.
animation = anim.FuncAnimation(
fig, animation_update, blit=True, frames=(l+4*xf)*anim_size*fps,
interval=500/fps, fargs=(test_data, xf, im, text_fields))
@@ -343,8 +346,8 @@ def evaluate():
bound = data.bins[-1] + 1
for t in tasks:
l = min_length
-while l < max_length + 12 and l < bound:
+while l < max_length + EXTRA_EVAL and l < bound:
-_, sq, _ = single_test(l, model, sess, t, FLAGS.nprint, batch_size)
+_, seq_err, _ = single_test(l, model, sess, t, FLAGS.nprint, batch_size)
l += 1
while l < bound + 1 and not data.test_set[t][l]:
l += 1
@@ -353,9 +356,9 @@ def evaluate():
_, _, test_data = single_test(l, model, sess, t, 0, anim_size)
animate(l, test_data, anim_size)
# More tests.
-_, sq = multi_test(data.forward_max, model, sess, t, FLAGS.nprint,
+_, seq_err = multi_test(data.forward_max, model, sess, t, FLAGS.nprint,
batch_size * 4)
-if sq < 0.01: # More tests.
+if seq_err < 0.01: # Super-test if we're very good and in large-test mode.
if data.forward_max > 4000 and len(tasks) == 1:
multi_test(data.forward_max, model, sess, tasks[0], FLAGS.nprint,
batch_size * 64, 0)
@@ -365,16 +368,18 @@ def interactive():
"""Interactively probe an existing model."""
with tf.Session() as sess:
model, _, _, _, _ = initialize(sess)
+sys.stdout.write("Input to Neural GPU, e.g., 0 1. Use -1 for PAD.\n")
sys.stdout.write("> ")
sys.stdout.flush()
inpt = sys.stdin.readline()
while inpt:
-ids = [int(c) for c in inpt.strip()]
+ids = [data.to_id(s) for s in inpt.strip().split()]
inpt, target = data.get_batch(len(ids), 1, False, "",
preset=(ids, [0 for _ in ids]))
_, res, _, _ = model.step(sess, inpt, target, False)
res = [np.argmax(o, axis=1) for o in res]
-print " ".join([str(output[0]) for output in res])
+res = [o for o in res[:len(ids)] if o > 0]
+print " " + " ".join([data.to_symbol(output[0]) for output in res])
sys.stdout.write("> ")
sys.stdout.flush()
inpt = sys.stdin.readline()
...