Commit f02e6013 authored by Alexander Gorban

Merge remote-tracking branch 'tensorflow/master'

parents f5f1e12a b719165d
/official/ @nealwu @k-w-w
/research/adversarial_crypto/ @dave-andersen
/research/adversarial_text/ @rsepassi
/research/adv_imagenet_models/ @AlexeyKurakin
/research/attention_ocr/ @alexgorban
/research/audioset/ @plakal @dpwe
/research/autoencoders/ @snurkabill
/research/cognitive_mapping_and_planning/ @s-gupta
/research/compression/ @nmjohn
/research/delf/ @andrefaraujo
/research/differential_privacy/ @panyx0718
/research/domain_adaptation/ @bousmalis @dmrd
/research/gan/ @joel-shor
/research/im2txt/ @cshallue
/research/inception/ @shlens @vincentvanhoucke
/research/learned_optimizer/ @olganw @nirum
/research/learning_to_remember_rare_events/ @lukaszkaiser @ofirnachum
/research/lfads/ @jazcollins @susillo
/research/lm_1b/ @oriolvinyals @panyx0718
/research/namignizer/ @knathanieltucker
/research/neural_gpu/ @lukaszkaiser
/research/neural_programmer/ @arvind2505
/research/next_frame_prediction/ @panyx0718
/research/object_detection/ @jch1 @tombstone @derekjchow @jesu9 @dreamdragon
/research/pcl_rl/ @ofirnachum
/research/ptn/ @xcyan @arkanath @hellojas @honglaklee
/research/real_nvp/ @laurent-dinh
/research/rebar/ @gjtucker
/research/resnet/ @panyx0718
/research/skip_thoughts/ @cshallue
/research/slim/ @sguada @nathansilberman
/research/street/ @theraysmith
/research/swivel/ @waterson
/research/syntaxnet/ @calberti @andorardo @bogatyy @markomernick
/research/tcn/ @coreylynch @sermanet
/research/textsum/ @panyx0718 @peterjliu
/research/transformer/ @daviddao
/research/video_prediction/ @cbfinn
/research/fivo/ @dieterichlawson
/samples/ @MarkDaoust
/tutorials/embedding/ @zffchen78 @a-dai
/tutorials/image/ @sherrym @shlens
/tutorials/rnn/ @lukaszkaiser @ebrevdo
@@ -2,7 +2,7 @@
This directory builds a convolutional neural net to classify the [MNIST
dataset](http://yann.lecun.com/exdb/mnist/) using the
[tf.data](https://www.tensorflow.org/api_docs/python/tf/data),
[tf.estimator.Estimator](https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator),
and
[tf.layers](https://www.tensorflow.org/api_docs/python/tf/layers)
@@ -12,18 +12,51 @@ APIs.

## Setup

To begin, you'll simply need the latest version of TensorFlow installed.

Then to train the model, run the following:

```
python mnist.py
```

The model will begin training and will automatically evaluate itself on the
validation data.

Illustrative unit tests and benchmarks can be run with:

```
python mnist_test.py
python mnist_test.py --benchmarks=.
```

## Exporting the model

You can export the model into Tensorflow [SavedModel](https://www.tensorflow.org/programmers_guide/saved_model) format by using the argument `--export_dir`:

```
python mnist.py --export_dir /tmp/mnist_saved_model
```

The SavedModel will be saved in a timestamped directory under `/tmp/mnist_saved_model/` (e.g. `/tmp/mnist_saved_model/1513630966/`).

**Getting predictions with SavedModel**

Use [`saved_model_cli`](https://www.tensorflow.org/programmers_guide/saved_model#cli_to_inspect_and_execute_savedmodel) to inspect and execute the SavedModel.

```
saved_model_cli run --dir /tmp/mnist_saved_model/TIMESTAMP --tag_set serve --signature_def classify --inputs image=examples.npy
```

`examples.npy` contains the data from `example5.png` and `example3.png` in a numpy array, in that order. The array values are normalized to values between 0 and 1.

The output should look similar to below:

```
Result for output key classes:
[5 3]
Result for output key probabilities:
[[ 1.53558474e-07 1.95694142e-13 1.31193523e-09 5.47467265e-03
   5.85711526e-22 9.94520664e-01 3.48423509e-06 2.65365645e-17
   9.78631419e-07 3.15522470e-08]
 [ 1.22413359e-04 5.87615965e-08 1.72251271e-06 9.39960718e-01
   3.30306928e-11 2.87386645e-02 2.82353517e-02 8.21146413e-18
   2.52568233e-03 4.15460236e-04]]
```
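If you need to produce `examples.npy` yourself, here is a minimal sketch; the PNG filenames come from this directory, but loading via Pillow is an assumption:

```
# Hedged sketch: build examples.npy from example5.png and example3.png.
# Pillow is an assumption; any loader that yields 28x28 grayscale works.
import numpy as np
from PIL import Image

images = [
    np.asarray(Image.open(name).convert('L'), dtype=np.float32) / 255.0
    for name in ['example5.png', 'example3.png']
]
np.save('examples.npy', np.stack(images))  # shape (2, 28, 28), values in [0, 1]
```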
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converts MNIST data to TFRecords file format with Example protos.
To read about optimizations that can be applied to the input preprocessing
stage, see: https://www.tensorflow.org/performance/performance_guide#input_pipeline_optimization.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import sys
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets import mnist
parser = argparse.ArgumentParser()
parser.add_argument('--directory', type=str, default='/tmp/mnist_data',
help='Directory to download data files and write the '
'converted result.')
parser.add_argument('--validation_size', type=int, default=0,
help='Number of examples to separate from the training '
'data for the validation set.')
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def convert_to(dataset, name, directory):
"""Converts a dataset to TFRecords."""
images = dataset.images
labels = dataset.labels
num_examples = dataset.num_examples
if images.shape[0] != num_examples:
raise ValueError('Images size %d does not match label size %d.' %
(images.shape[0], num_examples))
rows = images.shape[1]
cols = images.shape[2]
depth = images.shape[3]
filename = os.path.join(directory, name + '.tfrecords')
print('Writing', filename)
writer = tf.python_io.TFRecordWriter(filename)
for index in range(num_examples):
image_raw = images[index].tostring()
example = tf.train.Example(features=tf.train.Features(feature={
'height': _int64_feature(rows),
'width': _int64_feature(cols),
'depth': _int64_feature(depth),
'label': _int64_feature(int(labels[index])),
'image_raw': _bytes_feature(image_raw)}))
writer.write(example.SerializeToString())
writer.close()
def main(unused_argv):
# Get the data.
datasets = mnist.read_data_sets(FLAGS.directory,
dtype=tf.uint8,
reshape=False,
validation_size=FLAGS.validation_size)
# Convert to Examples and write the result to TFRecords.
convert_to(datasets.train, 'train', FLAGS.directory)
convert_to(datasets.validation, 'validation', FLAGS.directory)
convert_to(datasets.test, 'test', FLAGS.directory)
if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO)
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
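For context, records written by `convert_to()` above can be read back with `tf.data`; a minimal sketch (not part of this file) using the same feature names:

```
import tensorflow as tf

def parse_example(serialized):
  # Feature names match the Example protos written by convert_to().
  features = tf.parse_single_example(
      serialized,
      features={
          'image_raw': tf.FixedLenFeature([], tf.string),
          'label': tf.FixedLenFeature([], tf.int64),
      })
  image = tf.decode_raw(features['image_raw'], tf.uint8)
  image = tf.cast(image, tf.float32) / 255.0  # scale to [0.0, 1.0]
  label = tf.cast(features['label'], tf.int32)
  return image, label

ds = tf.data.TFRecordDataset('/tmp/mnist_data/train.tfrecords').map(parse_example)
```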
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""tf.data.Dataset interface to the MNIST dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import shutil
import gzip
import numpy as np
from six.moves import urllib
import tensorflow as tf
def read32(bytestream):
"""Read 4 bytes from bytestream as an unsigned 32-bit integer."""
dt = np.dtype(np.uint32).newbyteorder('>')
return np.frombuffer(bytestream.read(4), dtype=dt)[0]
def check_image_file_header(filename):
"""Validate that filename corresponds to images for the MNIST dataset."""
  with open(filename, 'rb') as f:
magic = read32(f)
num_images = read32(f)
rows = read32(f)
cols = read32(f)
if magic != 2051:
raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,
f.name))
if rows != 28 or cols != 28:
raise ValueError(
'Invalid MNIST file %s: Expected 28x28 images, found %dx%d' %
(f.name, rows, cols))
def check_labels_file_header(filename):
"""Validate that filename corresponds to labels for the MNIST dataset."""
  with open(filename, 'rb') as f:
magic = read32(f)
num_items = read32(f)
if magic != 2049:
raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,
f.name))
def download(directory, filename):
"""Download (and unzip) a file from the MNIST dataset, if it doesn't already exist."""
if not tf.gfile.Exists(directory):
tf.gfile.MakeDirs(directory)
filepath = os.path.join(directory, filename)
if tf.gfile.Exists(filepath):
return filepath
# CVDF mirror of http://yann.lecun.com/exdb/mnist/
url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'
zipped_filepath = filepath + '.gz'
print('Downloading %s to %s' % (url, zipped_filepath))
urllib.request.urlretrieve(url, zipped_filepath)
with gzip.open(zipped_filepath, 'rb') as f_in, open(filepath, 'wb') as f_out:
shutil.copyfileobj(f_in, f_out)
os.remove(zipped_filepath)
return filepath
def dataset(directory, images_file, labels_file):
images_file = download(directory, images_file)
labels_file = download(directory, labels_file)
check_image_file_header(images_file)
check_labels_file_header(labels_file)
def decode_image(image):
# Normalize from [0, 255] to [0.0, 1.0]
image = tf.decode_raw(image, tf.uint8)
image = tf.cast(image, tf.float32)
image = tf.reshape(image, [784])
return image / 255.0
def one_hot_label(label):
label = tf.decode_raw(label, tf.uint8) # tf.string -> tf.uint8
label = tf.reshape(label, []) # label is a scalar
return tf.one_hot(label, 10)
images = tf.data.FixedLengthRecordDataset(
images_file, 28 * 28, header_bytes=16).map(decode_image)
labels = tf.data.FixedLengthRecordDataset(
labels_file, 1, header_bytes=8).map(one_hot_label)
return tf.data.Dataset.zip((images, labels))
def train(directory):
"""tf.data.Dataset object for MNIST training data."""
return dataset(directory, 'train-images-idx3-ubyte',
'train-labels-idx1-ubyte')
def test(directory):
"""tf.data.Dataset object for MNIST test data."""
return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')
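A minimal usage sketch, mirroring how `mnist.py` consumes these datasets:

```
import tensorflow as tf
import dataset  # this module

# Download (if needed), decode, batch, and pull one batch of examples.
ds = dataset.train('/tmp/mnist_data').batch(32)
images, labels = ds.make_one_shot_iterator().get_next()
with tf.Session() as sess:
  batch_images, batch_labels = sess.run([images, labels])
  print(batch_images.shape, batch_labels.shape)  # (32, 784) (32, 10)
```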
@@ -22,228 +22,186 @@ import os
import sys

import tensorflow as tf

import dataset


class Model(object):
  """Class that defines a graph to recognize digits in the MNIST dataset."""

  def __init__(self, data_format):
    """Creates a model for classifying a hand-written digit.

    Args:
      data_format: Either 'channels_first' or 'channels_last'.
        'channels_first' is typically faster on GPUs while 'channels_last' is
        typically faster on CPUs. See
        https://www.tensorflow.org/performance/performance_guide#data_formats
    """
    if data_format == 'channels_first':
      self._input_shape = [-1, 1, 28, 28]
    else:
      assert data_format == 'channels_last'
      self._input_shape = [-1, 28, 28, 1]

    self.conv1 = tf.layers.Conv2D(
        32, 5, padding='same', data_format=data_format, activation=tf.nn.relu)
    self.conv2 = tf.layers.Conv2D(
        64, 5, padding='same', data_format=data_format, activation=tf.nn.relu)
    self.fc1 = tf.layers.Dense(1024, activation=tf.nn.relu)
    self.fc2 = tf.layers.Dense(10)
    self.dropout = tf.layers.Dropout(0.4)
    self.max_pool2d = tf.layers.MaxPooling2D(
        (2, 2), (2, 2), padding='same', data_format=data_format)

  def __call__(self, inputs, training):
    """Add operations to classify a batch of input images.

    Args:
      inputs: A Tensor representing a batch of input images.
      training: A boolean. Set to True to add operations required only when
        training the classifier.

    Returns:
      A logits Tensor with shape [<batch_size>, 10].
    """
    y = tf.reshape(inputs, self._input_shape)
    y = self.conv1(y)
    y = self.max_pool2d(y)
    y = self.conv2(y)
    y = self.max_pool2d(y)
    y = tf.layers.flatten(y)
    y = self.fc1(y)
    y = self.dropout(y, training=training)
    return self.fc2(y)


def model_fn(features, labels, mode, params):
  """The model_fn argument for creating an Estimator."""
  model = Model(params['data_format'])
  image = features
  if isinstance(image, dict):
    image = features['image']

  if mode == tf.estimator.ModeKeys.PREDICT:
    logits = model(image, training=False)
    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits),
    }
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        export_outputs={
            'classify': tf.estimator.export.PredictOutput(predictions)
        })
  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
    logits = model(image, training=True)
    loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
    accuracy = tf.metrics.accuracy(
        labels=tf.argmax(labels, axis=1),
        predictions=tf.argmax(logits, axis=1))
    # Name the accuracy tensor 'train_accuracy' to demonstrate the
    # LoggingTensorHook.
    tf.identity(accuracy[1], name='train_accuracy')
    tf.summary.scalar('train_accuracy', accuracy[1])
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.TRAIN,
        loss=loss,
        train_op=optimizer.minimize(loss, tf.train.get_or_create_global_step()))
  if mode == tf.estimator.ModeKeys.EVAL:
    logits = model(image, training=False)
    loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.EVAL,
        loss=loss,
        eval_metric_ops={
            'accuracy':
                tf.metrics.accuracy(
                    labels=tf.argmax(labels, axis=1),
                    predictions=tf.argmax(logits, axis=1)),
        })


def main(unused_argv):
  data_format = FLAGS.data_format
  if data_format is None:
    data_format = ('channels_first'
                   if tf.test.is_built_with_cuda() else 'channels_last')
  mnist_classifier = tf.estimator.Estimator(
      model_fn=model_fn,
      model_dir=FLAGS.model_dir,
      params={
          'data_format': data_format
      })

  # Train the model
  def train_input_fn():
    # When choosing shuffle buffer sizes, larger sizes result in better
    # randomness, while smaller sizes use less memory. MNIST is a small
    # enough dataset that we can easily shuffle the full epoch.
    ds = dataset.train(FLAGS.data_dir)
    ds = ds.cache().shuffle(buffer_size=50000).batch(FLAGS.batch_size).repeat(
        FLAGS.train_epochs)
    (images, labels) = ds.make_one_shot_iterator().get_next()
    return (images, labels)

  # Set up training hook that logs the training accuracy every 100 steps.
  tensors_to_log = {'train_accuracy': 'train_accuracy'}
  logging_hook = tf.train.LoggingTensorHook(
      tensors=tensors_to_log, every_n_iter=100)
  mnist_classifier.train(input_fn=train_input_fn, hooks=[logging_hook])

  # Evaluate the model and print results
  def eval_input_fn():
    return dataset.test(FLAGS.data_dir).batch(
        FLAGS.batch_size).make_one_shot_iterator().get_next()

  eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
  print()
  print('Evaluation results:\n\t%s' % eval_results)

  # Export the model
  if FLAGS.export_dir is not None:
    image = tf.placeholder(tf.float32, [None, 28, 28])
    input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
        'image': image,
    })
    mnist_classifier.export_savedmodel(FLAGS.export_dir, input_fn)


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--batch_size',
      type=int,
      default=100,
      help='Number of images to process in a batch')
  parser.add_argument(
      '--data_dir',
      type=str,
      default='/tmp/mnist_data',
      help='Path to directory containing the MNIST dataset')
  parser.add_argument(
      '--model_dir',
      type=str,
      default='/tmp/mnist_model',
      help='The directory where the model will be stored.')
  parser.add_argument(
      '--train_epochs', type=int, default=40, help='Number of epochs to train.')
  parser.add_argument(
      '--data_format',
      type=str,
      default=None,
      choices=['channels_first', 'channels_last'],
      help='A flag to override the data format used in the model. '
      'channels_first provides a performance boost on GPU but is not always '
      'compatible with CPU. If left unspecified, the data format will be '
      'chosen automatically based on whether TensorFlow was built for CPU or '
      'GPU.')
  parser.add_argument(
      '--export_dir',
      type=str,
      help='The directory where the exported SavedModel will be stored.')
  tf.logging.set_verbosity(tf.logging.INFO)
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
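To sanity-check an export from Python, a hedged sketch assuming TF 1.x with `tf.contrib.predictor` available; the timestamped path is an example value:

```
import numpy as np
import tensorflow as tf

# 'classify' matches the export_outputs key in model_fn above.
predict_fn = tf.contrib.predictor.from_saved_model(
    '/tmp/mnist_saved_model/1513630966', signature_def_key='classify')
result = predict_fn({'image': np.zeros([1, 28, 28], dtype=np.float32)})
print(result['classes'].shape, result['probabilities'].shape)  # (1,) (1, 10)
```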
@@ -18,25 +18,58 @@ from __future__ import division
from __future__ import print_function

import tensorflow as tf
import time

import mnist

BATCH_SIZE = 100


def dummy_input_fn():
  image = tf.random_uniform([BATCH_SIZE, 784])
  labels = tf.random_uniform([BATCH_SIZE], maxval=9, dtype=tf.int32)
  return image, tf.one_hot(labels, 10)


def make_estimator():
  data_format = 'channels_last'
  if tf.test.is_built_with_cuda():
    data_format = 'channels_first'
  return tf.estimator.Estimator(
      model_fn=mnist.model_fn, params={
          'data_format': data_format
      })


class Tests(tf.test.TestCase):

  def test_mnist(self):
    classifier = make_estimator()
    classifier.train(input_fn=dummy_input_fn, steps=2)
    eval_results = classifier.evaluate(input_fn=dummy_input_fn, steps=1)

    loss = eval_results['loss']
    global_step = eval_results['global_step']
    accuracy = eval_results['accuracy']
    self.assertEqual(loss.shape, ())
    self.assertEqual(2, global_step)
    self.assertEqual(accuracy.shape, ())

    input_fn = lambda: tf.random_uniform([3, 784])
    predictions_generator = classifier.predict(input_fn)
    for i in range(3):
      predictions = next(predictions_generator)
      self.assertEqual(predictions['probabilities'].shape, (10,))
      self.assertEqual(predictions['classes'].shape, ())

  def mnist_model_fn_helper(self, mode):
    features, labels = dummy_input_fn()
    image_count = features.shape[0]
    spec = mnist.model_fn(features, labels, mode, {
        'data_format': 'channels_last'
    })

    if mode == tf.estimator.ModeKeys.PREDICT:
      predictions = spec.predictions
      self.assertAllEqual(predictions['probabilities'].shape,
                          (image_count, 10))
      self.assertEqual(predictions['probabilities'].dtype, tf.float32)
@@ -65,5 +98,31 @@
      self.mnist_model_fn_helper(tf.estimator.ModeKeys.PREDICT)


class Benchmarks(tf.test.Benchmark):

  def benchmark_train_step_time(self):
    classifier = make_estimator()
    # Run one step to warmup any use of the GPU.
    classifier.train(input_fn=dummy_input_fn, steps=1)
    have_gpu = tf.test.is_gpu_available()
    num_steps = 1000 if have_gpu else 100
    name = 'train_step_time_%s' % ('gpu' if have_gpu else 'cpu')
    start = time.time()
    classifier.train(input_fn=dummy_input_fn, steps=num_steps)
    end = time.time()
    wall_time = (end - start) / num_steps
    self.report_benchmark(
        iters=num_steps,
        wall_time=wall_time,
        name=name,
        extras={
            'examples_per_sec': BATCH_SIZE / wall_time
        })


if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.ERROR)
  tf.test.main()
@@ -46,3 +46,6 @@ python imagenet_main.py --data_dir=/path/to/imagenet
The model will begin training and will automatically evaluate itself on the validation data roughly once per epoch.

Note that there are a number of other options you can specify, including `--model_dir` to choose where to store the model and `--resnet_size` to choose the model size (options include ResNet-18 through ResNet-200). See [`imagenet_main.py`](imagenet_main.py) for the full list of options.
### Pre-trained model
You can download a 190 MB pre-trained version of ResNet-50 achieving 75.3% top-1 single-crop accuracy here: [resnet50_2017_11_30.tar.gz](http://download.tensorflow.org/models/official/resnet50_2017_11_30.tar.gz). Simply download and uncompress the file, and point the model to the extracted directory using the `--model_dir` flag.
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Example code for TensorFlow Wide & Deep Tutorial using tf.estimator API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
...
@@ -18,7 +18,7 @@ installation](https://www.tensorflow.org/install).
- [attention_ocr](attention_ocr): a model for real-world image text
  extraction.
- [audioset](audioset): Models and supporting code for use with
  [AudioSet](http://g.co/audioset).
- [autoencoder](autoencoder): various autoencoders.
- [cognitive_mapping_and_planning](cognitive_mapping_and_planning):
  implementation of a spatial memory based mapping and planning architecture
@@ -29,6 +29,7 @@ installation](https://www.tensorflow.org/install).
- [differential_privacy](differential_privacy): privacy-preserving student
  models from multiple teachers.
- [domain_adaptation](domain_adaptation): domain separation networks.
- [gan](gan): generative adversarial networks.
- [im2txt](im2txt): image-to-text neural network for image captioning.
- [inception](inception): deep convolutional networks for computer vision.
- [learning_to_remember_rare_events](learning_to_remember_rare_events): a
@@ -60,6 +61,7 @@ installation](https://www.tensorflow.org/install).
  using a Deep RNN.
- [swivel](swivel): the Swivel algorithm for generating word embeddings.
- [syntaxnet](syntaxnet): neural models of natural language syntax.
- [tcn](tcn): Self-supervised representation learning from multi-view video.
- [textsum](textsum): sequence-to-sequence with attention model for text
  summarization.
- [transformer](transformer): spatial transformer network, which allows the
...
@@ -29,6 +29,7 @@ Network Architecture | Adversarial training | Checkpoint
Inception v3 | Step L.L. | [adv_inception_v3_2017_08_18.tar.gz](http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz)
Inception v3 | Step L.L. on ensemble of 3 models | [ens3_adv_inception_v3_2017_08_18.tar.gz](http://download.tensorflow.org/models/ens3_adv_inception_v3_2017_08_18.tar.gz)
Inception v3 | Step L.L. on ensemble of 4 models | [ens4_adv_inception_v3_2017_08_18.tar.gz](http://download.tensorflow.org/models/ens4_adv_inception_v3_2017_08_18.tar.gz)
Inception ResNet v2 | Step L.L. | [adv_inception_resnet_v2_2017_12_18.tar.gz](http://download.tensorflow.org/models/adv_inception_resnet_v2_2017_12_18.tar.gz)
Inception ResNet v2 | Step L.L. on ensemble of 3 models | [ens_adv_inception_resnet_v2_2017_08_18.tar.gz](http://download.tensorflow.org/models/ens_adv_inception_resnet_v2_2017_08_18.tar.gz)

All checkpoints are compatible with
...
@@ -17,6 +17,7 @@
import json

from six.moves import xrange

import tensorflow as tf

from entropy_coder.lib import blocks
...
# Filtering Variational Objectives
This folder contains a TensorFlow implementation of the algorithms from
Chris J. Maddison\*, Dieterich Lawson\*, George Tucker\*, Nicolas Heess, Mohammad Norouzi, Andriy Mnih, Arnaud Doucet, and Yee Whye Teh. "Filtering Variational Objectives." NIPS 2017.
[https://arxiv.org/abs/1705.09279](https://arxiv.org/abs/1705.09279)
This code implements 3 different bounds for training sequential latent variable models: the evidence lower bound (ELBO), the importance weighted auto-encoder bound (IWAE), and our bound, the filtering variational objective (FIVO).
Additionally it contains an implementation of the variational recurrent neural network (VRNN), a sequential latent variable model that can be trained using these three objectives. This repo provides code for training a VRNN to do sequence modeling of pianoroll and speech data.
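For orientation, the standard forms of these objectives (with $K$ samples or particles; not copied from the code) are:

$$\mathcal{L}_{\mathrm{IWAE}}(x) = \mathbb{E}_{z^{1:K} \sim q(\cdot \mid x)}\left[\log \frac{1}{K} \sum_{k=1}^{K} \frac{p(x, z^{k})}{q(z^{k} \mid x)}\right],$$

which recovers the ELBO at $K = 1$. FIVO replaces the simple average with a particle filter's unbiased estimate $\hat{p}_{\mathrm{PF}}(x)$ of the marginal likelihood, $\mathcal{L}_{\mathrm{FIVO}}(x) = \mathbb{E}\big[\log \hat{p}_{\mathrm{PF}}(x)\big]$, and reduces to IWAE when resampling is disabled.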
#### Directory Structure
The important parts of the code are organized as follows.
```
fivo.py # main script, contains flag definitions
runners.py # graph construction code for training and evaluation
bounds.py # code for computing each bound
data
├── datasets.py # readers for pianoroll and speech datasets
├── calculate_pianoroll_mean.py # preprocesses the pianoroll datasets
└── create_timit_dataset.py # preprocesses the TIMIT dataset
models
└── vrnn.py # variational RNN implementation
bin
├── run_train.sh # an example script that runs training
├── run_eval.sh # an example script that runs evaluation
└── download_pianorolls.sh # a script that downloads the pianoroll files
```
### Training on Pianorolls
Requirements before we start:
* TensorFlow (see [tensorflow.org](http://tensorflow.org) for how to install)
* [scipy](https://www.scipy.org/)
* [sonnet](https://github.com/deepmind/sonnet)
#### Download the Data
The pianoroll datasets are encoded as pickled sparse arrays and are available at [http://www-etud.iro.umontreal.ca/~boulanni/icml2012](http://www-etud.iro.umontreal.ca/~boulanni/icml2012). You can use the script `bin/download_pianorolls.sh` to download the files into a directory of your choosing.
```
export PIANOROLL_DIR=~/pianorolls
mkdir $PIANOROLL_DIR
sh bin/download_pianorolls.sh $PIANOROLL_DIR
```
#### Preprocess the Data
The script `calculate_pianoroll_mean.py` loads a pianoroll pickle file, calculates the mean, updates the pickle file to include the mean under the key `train_mean`, and writes the file back to disk in-place. You should do this for all pianoroll datasets you wish to train on.
```
python data/calculate_pianoroll_mean.py --in_file=$PIANOROLL_DIR/piano-midi.de.pkl
python data/calculate_pianoroll_mean.py --in_file=$PIANOROLL_DIR/nottingham.pkl
python data/calculate_pianoroll_mean.py --in_file=$PIANOROLL_DIR/musedata.pkl
python data/calculate_pianoroll_mean.py --in_file=$PIANOROLL_DIR/jsb.pkl
```
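To verify the mean was added, a hedged sketch (assuming the pickle is a dict with the usual split keys, which is how the preprocessing above treats it):

```
import os
import pickle

with open(os.path.expanduser('~/pianorolls/jsb.pkl'), 'rb') as f:
  data = pickle.load(f)
print(sorted(data.keys()))  # expect something like: 'test', 'train', 'train_mean', 'valid'
```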
#### Training
Now we can train a model. Here is a standard training run, taken from `bin/run_train.sh`:
```
python fivo.py \
--mode=train \
--logdir=/tmp/fivo \
--model=vrnn \
--bound=fivo \
--summarize_every=100 \
--batch_size=4 \
--num_samples=4 \
--learning_rate=0.0001 \
--dataset_path="$PIANOROLL_DIR/jsb.pkl" \
--dataset_type="pianoroll"
```
You should see output that looks something like this (with a lot of extra logging cruft):
```
Step 1, fivo bound per timestep: -11.801050
global_step/sec: 9.89825
Step 101, fivo bound per timestep: -11.198309
global_step/sec: 9.55475
Step 201, fivo bound per timestep: -11.287262
global_step/sec: 9.68146
step 301, fivo bound per timestep: -11.316490
global_step/sec: 9.94295
Step 401, fivo bound per timestep: -11.151743
```
You will also see lines saying `Out of range: exceptions.StopIteration: Iteration finished`. This is not an error and is fine.
#### Evaluation
You can also evaluate saved checkpoints. The `eval` mode loads a model checkpoint, tests its performance on all items in a dataset, and reports the log-likelihood averaged over the dataset. For example here is a command, taken from `bin/run_eval.sh`, that will evaluate a JSB model on the test set:
```
python fivo.py \
--mode=eval \
--split=test \
--alsologtostderr \
--logdir=/tmp/fivo \
--model=vrnn \
--batch_size=4 \
--num_samples=4 \
--dataset_path="$PIANOROLL_DIR/jsb.pkl" \
--dataset_type="pianoroll"
```
You should see output like this:
```
Model restored from step 1, evaluating.
test elbo ll/t: -12.299635, iwae ll/t: -12.128336 fivo ll/t: -11.656939
test elbo ll/seq: -754.750312, iwae ll/seq: -744.238773 fivo ll/seq: -715.3121490
```
The evaluation script prints log-likelihood in both nats per timestep (ll/t) and nats per sequence (ll/seq) for all three bounds.
### Training on TIMIT
The TIMIT speech dataset is available at the [Linguistic Data Consortium website](https://catalog.ldc.upenn.edu/LDC93S1), but is unfortunately not free. These instructions will proceed assuming you have downloaded the TIMIT archive and extracted it into the directory `$RAW_TIMIT_DIR`.
#### Preprocess TIMIT
We preprocess TIMIT (as described in our paper) and write it out to a series of TFRecord files. To prepare the TIMIT dataset, use the script `create_timit_dataset.py`:
```
export TIMIT_DIR=~/timit_dataset
mkdir $TIMIT_DIR
python data/create_timit_dataset.py \
--raw_timit_dir=$RAW_TIMIT_DIR \
--out_dir=$TIMIT_DIR
```
You should see this exact output:
```
4389 train / 231 valid / 1680 test
train mean: 0.006060 train std: 548.136169
```
#### Training on TIMIT
This is very similar to training on pianoroll datasets, with just a few flags switched.
```
python fivo.py \
--mode=train \
--logdir=/tmp/fivo \
--model=vrnn \
--bound=fivo \
--summarize_every=100 \
--batch_size=4 \
--num_samples=4 \
--learning_rate=0.0001 \
--dataset_path="$TIMIT_DIR/train" \
--dataset_type="speech"
```
### Contact
This codebase is maintained by Dieterich Lawson, reachable via email at dieterichl@google.com. For questions and issues please open an issue on the tensorflow/models issues tracker and assign it to @dieterichlawson.
#!/bin/bash
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# A script to download the pianoroll datasets.
# Accepts one argument, the directory to put the files in
if [ -z "$1" ]
then
echo "Error, must provide a directory to download the files to."
  exit 1
fi
echo "Downloading datasets into $1"
curl -s "http://www-etud.iro.umontreal.ca/~boulanni/Piano-midi.de.pickle" > $1/piano-midi.de.pkl
curl -s "http://www-etud.iro.umontreal.ca/~boulanni/Nottingham.pickle" > $1/nottingham.pkl
curl -s "http://www-etud.iro.umontreal.ca/~boulanni/MuseData.pickle" > $1/musedata.pkl
curl -s "http://www-etud.iro.umontreal.ca/~boulanni/JSB%20Chorales.pickle" > $1/jsb.pkl
#!/bin/bash
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# An example of running evaluation.
PIANOROLL_DIR=$HOME/pianorolls
python fivo.py \
--mode=eval \
--logdir=/tmp/fivo \
--model=vrnn \
--batch_size=4 \
--num_samples=4 \
--split=test \
--dataset_path="$PIANOROLL_DIR/jsb.pkl" \
--dataset_type="pianoroll"
#!/bin/bash
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# An example of running training.
PIANOROLL_DIR=$HOME/pianorolls
python fivo.py \
--mode=train \
--logdir=/tmp/fivo \
--model=vrnn \
--bound=fivo \
--summarize_every=100 \
--batch_size=4 \
--num_samples=4 \
--learning_rate=0.0001 \
--dataset_path="$PIANOROLL_DIR/jsb.pkl" \
--dataset_type="pianoroll"
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of objectives for training stochastic latent variable models.
Contains implementations of the Importance Weighted Autoencoder objective (IWAE)
and the Filtering Variational objective (FIVO).
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import nested_utils as nested
def iwae(cell,
inputs,
seq_lengths,
num_samples=1,
parallel_iterations=30,
swap_memory=True):
"""Computes the IWAE lower bound on the log marginal probability.
This method accepts a stochastic latent variable model and some observations
and computes a stochastic lower bound on the log marginal probability of the
observations. The IWAE estimator is defined by averaging multiple importance
weights. For more details see "Importance Weighted Autoencoders" by Burda
et al. https://arxiv.org/abs/1509.00519.
When num_samples = 1, this bound becomes the evidence lower bound (ELBO).
Args:
cell: A callable that implements one timestep of the model. See
models/vrnn.py for an example.
inputs: The inputs to the model. A potentially nested list or tuple of
Tensors each of shape [max_seq_len, batch_size, ...]. The Tensors must
have a rank at least two and have matching shapes in the first two
dimensions, which represent time and the batch respectively. At each
timestep 'cell' will be called with a slice of the Tensors in inputs.
seq_lengths: A [batch_size] Tensor of ints encoding the length of each
sequence in the batch (sequences can be padded to a common length).
num_samples: The number of samples to use.
parallel_iterations: The number of parallel iterations to use for the
internal while loop.
swap_memory: Whether GPU-CPU memory swapping should be enabled for the
internal while loop.
Returns:
log_p_hat: A Tensor of shape [batch_size] containing IWAE's estimate of the
log marginal probability of the observations.
kl: A Tensor of shape [batch_size] containing the kl divergence
from q(z|x) to p(z), averaged over samples.
log_weights: A Tensor of shape [max_seq_len, batch_size, num_samples]
containing the log weights at each timestep. Will not be valid for
timesteps past the end of a sequence.
log_ess: A Tensor of shape [max_seq_len, batch_size] containing the log
effective sample size at each timestep. Will not be valid for timesteps
past the end of a sequence.
"""
batch_size = tf.shape(seq_lengths)[0]
max_seq_len = tf.reduce_max(seq_lengths)
seq_mask = tf.transpose(
tf.sequence_mask(seq_lengths, maxlen=max_seq_len, dtype=tf.float32),
perm=[1, 0])
if num_samples > 1:
inputs, seq_mask = nested.tile_tensors([inputs, seq_mask], [1, num_samples])
inputs_ta, mask_ta = nested.tas_for_tensors([inputs, seq_mask], max_seq_len)
t0 = tf.constant(0, tf.int32)
init_states = cell.zero_state(batch_size * num_samples, tf.float32)
ta_names = ['log_weights', 'log_ess']
tas = [tf.TensorArray(tf.float32, max_seq_len, name='%s_ta' % n)
for n in ta_names]
log_weights_acc = tf.zeros([num_samples, batch_size], dtype=tf.float32)
kl_acc = tf.zeros([num_samples * batch_size], dtype=tf.float32)
accs = (log_weights_acc, kl_acc)
def while_predicate(t, *unused_args):
return t < max_seq_len
def while_step(t, rnn_state, tas, accs):
"""Implements one timestep of IWAE computation."""
log_weights_acc, kl_acc = accs
cur_inputs, cur_mask = nested.read_tas([inputs_ta, mask_ta], t)
# Run the cell for one step.
log_q_z, log_p_z, log_p_x_given_z, kl, new_state = cell(
cur_inputs,
rnn_state,
cur_mask,
)
# Compute the incremental weight and use it to update the current
# accumulated weight.
kl_acc += kl * cur_mask
log_alpha = (log_p_x_given_z + log_p_z - log_q_z) * cur_mask
log_alpha = tf.reshape(log_alpha, [num_samples, batch_size])
log_weights_acc += log_alpha
# Calculate the effective sample size.
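    # In log space: log_ess = 2*logsumexp(log_w) - logsumexp(2*log_w),
    # i.e. ESS = (sum_k w_k)^2 / sum_k w_k^2 computed stably.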
ess_num = 2 * tf.reduce_logsumexp(log_weights_acc, axis=0)
ess_denom = tf.reduce_logsumexp(2 * log_weights_acc, axis=0)
log_ess = ess_num - ess_denom
# Update the Tensorarrays and accumulators.
ta_updates = [log_weights_acc, log_ess]
new_tas = [ta.write(t, x) for ta, x in zip(tas, ta_updates)]
new_accs = (log_weights_acc, kl_acc)
return t + 1, new_state, new_tas, new_accs
_, _, tas, accs = tf.while_loop(
while_predicate,
while_step,
loop_vars=(t0, init_states, tas, accs),
parallel_iterations=parallel_iterations,
swap_memory=swap_memory)
log_weights, log_ess = [x.stack() for x in tas]
final_log_weights, kl = accs
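  # Average the importance weights: log_p_hat = logsumexp(log_w) - log(K).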
log_p_hat = (tf.reduce_logsumexp(final_log_weights, axis=0) -
tf.log(tf.to_float(num_samples)))
kl = tf.reduce_mean(tf.reshape(kl, [num_samples, batch_size]), axis=0)
log_weights = tf.transpose(log_weights, perm=[0, 2, 1])
return log_p_hat, kl, log_weights, log_ess
def ess_criterion(num_samples, log_ess, unused_t):
"""A criterion that resamples based on effective sample size."""
return log_ess <= tf.log(num_samples / 2.0)
def never_resample_criterion(unused_num_samples, log_ess, unused_t):
"""A criterion that never resamples."""
return tf.cast(tf.zeros_like(log_ess), tf.bool)
def always_resample_criterion(unused_num_samples, log_ess, unused_t):
"""A criterion resamples at every timestep."""
return tf.cast(tf.ones_like(log_ess), tf.bool)
def fivo(cell,
inputs,
seq_lengths,
num_samples=1,
resampling_criterion=ess_criterion,
parallel_iterations=30,
swap_memory=True,
random_seed=None):
"""Computes the FIVO lower bound on the log marginal probability.
This method accepts a stochastic latent variable model and some observations
and computes a stochastic lower bound on the log marginal probability of the
observations. The lower bound is defined by a particle filter's unbiased
estimate of the marginal probability of the observations. For more details see
"Filtering Variational Objectives" by Maddison et al.
https://arxiv.org/abs/1705.09279.
When the resampling criterion is "never resample", this bound becomes IWAE.
Args:
cell: A callable that implements one timestep of the model. See
models/vrnn.py for an example.
inputs: The inputs to the model. A potentially nested list or tuple of
Tensors each of shape [max_seq_len, batch_size, ...]. The Tensors must
have a rank at least two and have matching shapes in the first two
dimensions, which represent time and the batch respectively. At each
timestep 'cell' will be called with a slice of the Tensors in inputs.
seq_lengths: A [batch_size] Tensor of ints encoding the length of each
sequence in the batch (sequences can be padded to a common length).
num_samples: The number of particles to use in each particle filter.
resampling_criterion: The resampling criterion to use for this particle
filter. Must accept the number of samples, the effective sample size,
and the current timestep and return a boolean Tensor of shape [batch_size]
indicating whether each particle filter should resample. See
ess_criterion and related functions defined in this file for examples.
parallel_iterations: The number of parallel iterations to use for the
internal while loop. Note that values greater than 1 can introduce
non-determinism even when random_seed is provided.
swap_memory: Whether GPU-CPU memory swapping should be enabled for the
internal while loop.
random_seed: The random seed to pass to the resampling operations in
the particle filter. Mainly useful for testing.
Returns:
log_p_hat: A Tensor of shape [batch_size] containing FIVO's estimate of the
log marginal probability of the observations.
kl: A Tensor of shape [batch_size] containing the sum over time of the kl
divergence from q_t(z_t|x) to p_t(z_t), averaged over particles. Note that
this includes kl terms from trajectories that are culled during resampling
steps.
log_weights: A Tensor of shape [max_seq_len, batch_size, num_samples]
containing the log weights at each timestep of the particle filter. Note
that on timesteps when a resampling operation is performed the log weights
are reset to 0. Will not be valid for timesteps past the end of a
sequence.
log_ess: A Tensor of shape [max_seq_len, batch_size] containing the log
effective sample size of each particle filter at each timestep. Will not
be valid for timesteps past the end of a sequence.
resampled: A Tensor of shape [max_seq_len, batch_size] indicating when the
particle filters resampled. Will be 1.0 on timesteps when resampling
occurred and 0.0 on timesteps when it did not.
"""
# batch_size represents the number of particle filters running in parallel.
batch_size = tf.shape(seq_lengths)[0]
max_seq_len = tf.reduce_max(seq_lengths)
seq_mask = tf.transpose(
tf.sequence_mask(seq_lengths, maxlen=max_seq_len, dtype=tf.float32),
perm=[1, 0])
# Each sequence in the batch will be the input data for a different
# particle filter. The batch will be laid out as:
# particle 1 of particle filter 1
# particle 1 of particle filter 2
# ...
# particle 1 of particle filter batch_size
# particle 2 of particle filter 1
# ...
# particle num_samples of particle filter batch_size
if num_samples > 1:
inputs, seq_mask = nested.tile_tensors([inputs, seq_mask], [1, num_samples])
inputs_ta, mask_ta = nested.tas_for_tensors([inputs, seq_mask], max_seq_len)
t0 = tf.constant(0, tf.int32)
init_states = cell.zero_state(batch_size * num_samples, tf.float32)
ta_names = ['log_weights', 'log_ess', 'resampled']
tas = [tf.TensorArray(tf.float32, max_seq_len, name='%s_ta' % n)
for n in ta_names]
log_weights_acc = tf.zeros([num_samples, batch_size], dtype=tf.float32)
log_p_hat_acc = tf.zeros([batch_size], dtype=tf.float32)
kl_acc = tf.zeros([num_samples * batch_size], dtype=tf.float32)
accs = (log_weights_acc, log_p_hat_acc, kl_acc)
def while_predicate(t, *unused_args):
return t < max_seq_len
def while_step(t, rnn_state, tas, accs):
"""Implements one timestep of FIVO computation."""
log_weights_acc, log_p_hat_acc, kl_acc = accs
cur_inputs, cur_mask = nested.read_tas([inputs_ta, mask_ta], t)
# Run the cell for one step.
log_q_z, log_p_z, log_p_x_given_z, kl, new_state = cell(
cur_inputs,
rnn_state,
cur_mask,
)
# Compute the incremental weight and use it to update the current
# accumulated weight.
kl_acc += kl * cur_mask
log_alpha = (log_p_x_given_z + log_p_z - log_q_z) * cur_mask
log_alpha = tf.reshape(log_alpha, [num_samples, batch_size])
log_weights_acc += log_alpha
# Calculate the effective sample size.
ess_num = 2 * tf.reduce_logsumexp(log_weights_acc, axis=0)
ess_denom = tf.reduce_logsumexp(2 * log_weights_acc, axis=0)
log_ess = ess_num - ess_denom
# Calculate the ancestor indices via resampling. Because we maintain the
# log unnormalized weights, we pass the weights in as logits, allowing
# the distribution object to apply a softmax and normalize them.
resampling_dist = tf.contrib.distributions.Categorical(
logits=tf.transpose(log_weights_acc, perm=[1, 0]))
ancestor_inds = tf.stop_gradient(
resampling_dist.sample(sample_shape=num_samples, seed=random_seed))
# Because the batch is flattened and laid out as discussed
# above, we must modify ancestor_inds to index the proper samples.
# The particles in the ith filter are distributed every batch_size rows
# in the batch, and offset i rows from the top. So, to correct the indices
# we multiply by the batch_size and add the proper offset. Crucially,
# when ancestor_inds is flattened the layout of the batch is maintained.
offset = tf.expand_dims(tf.range(batch_size), 0)
ancestor_inds = tf.reshape(ancestor_inds * batch_size + offset, [-1])
noresample_inds = tf.range(num_samples * batch_size)
# Decide whether or not we should resample; don't resample if we are past
# the end of a sequence.
should_resample = resampling_criterion(num_samples, log_ess, t)
should_resample = tf.logical_and(should_resample,
cur_mask[:batch_size] > 0.)
float_should_resample = tf.to_float(should_resample)
ancestor_inds = tf.where(
tf.tile(should_resample, [num_samples]),
ancestor_inds,
noresample_inds)
new_state = nested.gather_tensors(new_state, ancestor_inds)
# Update the TensorArrays before we reset the weights so that we capture
# the incremental weights and not zeros.
ta_updates = [log_weights_acc, log_ess, float_should_resample]
new_tas = [ta.write(t, x) for ta, x in zip(tas, ta_updates)]
# For the particle filters that resampled, update log_p_hat and
# reset weights to zero.
log_p_hat_update = tf.reduce_logsumexp(
log_weights_acc, axis=0) - tf.log(tf.to_float(num_samples))
log_p_hat_acc += log_p_hat_update * float_should_resample
log_weights_acc *= (1. - tf.tile(float_should_resample[tf.newaxis, :],
[num_samples, 1]))
new_accs = (log_weights_acc, log_p_hat_acc, kl_acc)
return t + 1, new_state, new_tas, new_accs
_, _, tas, accs = tf.while_loop(
while_predicate,
while_step,
loop_vars=(t0, init_states, tas, accs),
parallel_iterations=parallel_iterations,
swap_memory=swap_memory)
log_weights, log_ess, resampled = [x.stack() for x in tas]
final_log_weights, log_p_hat, kl = accs
# Add in the final weight update to log_p_hat.
log_p_hat += (tf.reduce_logsumexp(final_log_weights, axis=0) -
tf.log(tf.to_float(num_samples)))
kl = tf.reduce_mean(tf.reshape(kl, [num_samples, batch_size]), axis=0)
log_weights = tf.transpose(log_weights, perm=[0, 2, 1])
return log_p_hat, kl, log_weights, log_ess, resampled