Commit 4f9d1024 authored by Chris Shallue

Open source the image-to-text model based on the "Show and Tell" paper.

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow_models.im2txt.ops.image_embedding."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from im2txt.ops import image_embedding
class InceptionV3Test(tf.test.TestCase):
def setUp(self):
super(InceptionV3Test, self).setUp()
batch_size = 4
height = 299
width = 299
num_channels = 3
self._images = tf.placeholder(tf.float32,
[batch_size, height, width, num_channels])
self._batch_size = batch_size
def _countInceptionParameters(self):
"""Counts the number of parameters in the inception model at top scope."""
counter = {}
for v in tf.all_variables():
name_tokens = v.op.name.split("/")
if name_tokens[0] == "InceptionV3":
name = "InceptionV3/" + name_tokens[1]
num_params = v.get_shape().num_elements()
assert num_params
counter[name] = counter.get(name, 0) + num_params
return counter
def _verifyParameterCounts(self):
"""Verifies the number of parameters in the inception model."""
param_counts = self._countInceptionParameters()
expected_param_counts = {
"InceptionV3/Conv2d_1a_3x3": 960,
"InceptionV3/Conv2d_2a_3x3": 9312,
"InceptionV3/Conv2d_2b_3x3": 18624,
"InceptionV3/Conv2d_3b_1x1": 5360,
"InceptionV3/Conv2d_4a_3x3": 138816,
"InceptionV3/Mixed_5b": 256368,
"InceptionV3/Mixed_5c": 277968,
"InceptionV3/Mixed_5d": 285648,
"InceptionV3/Mixed_6a": 1153920,
"InceptionV3/Mixed_6b": 1298944,
"InceptionV3/Mixed_6c": 1692736,
"InceptionV3/Mixed_6d": 1692736,
"InceptionV3/Mixed_6e": 2143872,
"InceptionV3/Mixed_7a": 1699584,
"InceptionV3/Mixed_7b": 5047872,
"InceptionV3/Mixed_7c": 6080064,
}
self.assertDictEqual(expected_param_counts, param_counts)
def _assertCollectionSize(self, expected_size, collection):
actual_size = len(tf.get_collection(collection))
if expected_size != actual_size:
self.fail("Found %d items in collection %s (expected %d)." %
(actual_size, collection, expected_size))
def testTrainableTrueIsTrainingTrue(self):
embeddings = image_embedding.inception_v3(
self._images, trainable=True, is_training=True)
self.assertEqual([self._batch_size, 2048], embeddings.get_shape().as_list())
self._verifyParameterCounts()
self._assertCollectionSize(376, tf.GraphKeys.VARIABLES)
self._assertCollectionSize(188, tf.GraphKeys.TRAINABLE_VARIABLES)
self._assertCollectionSize(188, tf.GraphKeys.UPDATE_OPS)
self._assertCollectionSize(94, tf.GraphKeys.REGULARIZATION_LOSSES)
self._assertCollectionSize(0, tf.GraphKeys.LOSSES)
self._assertCollectionSize(23, tf.GraphKeys.SUMMARIES)
def testTrainableTrueIsTrainingFalse(self):
embeddings = image_embedding.inception_v3(
self._images, trainable=True, is_training=False)
self.assertEqual([self._batch_size, 2048], embeddings.get_shape().as_list())
self._verifyParameterCounts()
self._assertCollectionSize(376, tf.GraphKeys.VARIABLES)
self._assertCollectionSize(188, tf.GraphKeys.TRAINABLE_VARIABLES)
self._assertCollectionSize(0, tf.GraphKeys.UPDATE_OPS)
self._assertCollectionSize(94, tf.GraphKeys.REGULARIZATION_LOSSES)
self._assertCollectionSize(0, tf.GraphKeys.LOSSES)
self._assertCollectionSize(23, tf.GraphKeys.SUMMARIES)
def testTrainableFalseIsTrainingTrue(self):
embeddings = image_embedding.inception_v3(
self._images, trainable=False, is_training=True)
self.assertEqual([self._batch_size, 2048], embeddings.get_shape().as_list())
self._verifyParameterCounts()
self._assertCollectionSize(376, tf.GraphKeys.VARIABLES)
self._assertCollectionSize(0, tf.GraphKeys.TRAINABLE_VARIABLES)
self._assertCollectionSize(0, tf.GraphKeys.UPDATE_OPS)
self._assertCollectionSize(0, tf.GraphKeys.REGULARIZATION_LOSSES)
self._assertCollectionSize(0, tf.GraphKeys.LOSSES)
self._assertCollectionSize(23, tf.GraphKeys.SUMMARIES)
def testTrainableFalseIsTrainingFalse(self):
embeddings = image_embedding.inception_v3(
self._images, trainable=False, is_training=False)
self.assertEqual([self._batch_size, 2048], embeddings.get_shape().as_list())
self._verifyParameterCounts()
self._assertCollectionSize(376, tf.GraphKeys.VARIABLES)
self._assertCollectionSize(0, tf.GraphKeys.TRAINABLE_VARIABLES)
self._assertCollectionSize(0, tf.GraphKeys.UPDATE_OPS)
self._assertCollectionSize(0, tf.GraphKeys.REGULARIZATION_LOSSES)
self._assertCollectionSize(0, tf.GraphKeys.LOSSES)
self._assertCollectionSize(23, tf.GraphKeys.SUMMARIES)
if __name__ == "__main__":
tf.test.main()
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper functions for image preprocessing."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
def distort_image(image, thread_id):
"""Perform random distortions on an image.
Args:
image: A float32 Tensor of shape [height, width, 3] with values in [0, 1).
thread_id: Preprocessing thread id used to select the ordering of color
distortions. The total number of preprocessing threads should be a multiple of 2.
Returns:
distorted_image: A float32 Tensor of shape [height, width, 3] with values in
[0, 1].
"""
# Randomly flip horizontally.
with tf.name_scope("flip_horizontal", values=[image]):
image = tf.image.random_flip_left_right(image)
# Randomly distort the colors based on thread id.
color_ordering = thread_id % 2
with tf.name_scope("distort_color", values=[image]):
if color_ordering == 0:
image = tf.image.random_brightness(image, max_delta=32. / 255.)
image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
image = tf.image.random_hue(image, max_delta=0.032)
image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
elif color_ordering == 1:
image = tf.image.random_brightness(image, max_delta=32. / 255.)
image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
image = tf.image.random_hue(image, max_delta=0.032)
# The random_* ops do not necessarily clamp.
image = tf.clip_by_value(image, 0.0, 1.0)
return image
def process_image(encoded_image,
is_training,
height,
width,
resize_height=346,
resize_width=346,
thread_id=0,
image_format="jpeg"):
"""Decode an image, resize and apply random distortions.
In training, images are distorted slightly differently depending on thread_id.
Args:
encoded_image: String Tensor containing the image.
is_training: Boolean; whether preprocessing for training or eval.
height: Height of the output image.
width: Width of the output image.
resize_height: If > 0, resize height before crop to final dimensions.
resize_width: If > 0, resize width before crop to final dimensions.
thread_id: Preprocessing thread id used to select the ordering of color
distortions. The total number of preprocessing threads should be a multiple of 2.
image_format: "jpeg" or "png".
Returns:
A float32 Tensor of shape [height, width, 3] with values in [-1, 1].
Raises:
ValueError: If image_format is invalid.
"""
# Helper function to log an image summary to the visualizer. Summaries are
# only logged in thread 0.
def image_summary(name, image):
if not thread_id:
tf.image_summary(name, tf.expand_dims(image, 0))
# Decode image into a float32 Tensor of shape [?, ?, 3] with values in [0, 1).
with tf.name_scope("decode", values=[encoded_image]):
if image_format == "jpeg":
image = tf.image.decode_jpeg(encoded_image, channels=3)
elif image_format == "png":
image = tf.image.decode_png(encoded_image, channels=3)
else:
raise ValueError("Invalid image format: %s" % image_format)
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
image_summary("original_image", image)
# Resize image.
assert (resize_height > 0) == (resize_width > 0)
if resize_height:
image = tf.image.resize_images(image,
new_height=resize_height,
new_width=resize_width,
method=tf.image.ResizeMethod.BILINEAR)
# Crop to final dimensions.
if is_training:
image = tf.random_crop(image, [height, width, 3])
else:
# Central crop, assuming resize_height > height, resize_width > width.
image = tf.image.resize_image_with_crop_or_pad(image, height, width)
image_summary("resized_image", image)
# Randomly distort the image.
if is_training:
image = distort_image(image, thread_id)
image_summary("final_image", image)
# Rescale to [-1, 1] instead of [0, 1].
image = tf.sub(image, 0.5)
image = tf.mul(image, 2.0)
return image
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Input ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
def parse_sequence_example(serialized, image_feature, caption_feature):
"""Parses a tensorflow.SequenceExample into an image and caption.
Args:
serialized: A scalar string Tensor; a single serialized SequenceExample.
image_feature: Name of SequenceExample context feature containing image
data.
caption_feature: Name of SequenceExample feature list containing integer
captions.
Returns:
encoded_image: A scalar string Tensor containing a JPEG encoded image.
caption: A 1-D int64 Tensor with dynamically specified length.
"""
context, sequence = tf.parse_single_sequence_example(
serialized,
context_features={
image_feature: tf.FixedLenFeature([], dtype=tf.string)
},
sequence_features={
caption_feature: tf.FixedLenSequenceFeature([], dtype=tf.int64),
})
encoded_image = context[image_feature]
caption = sequence[caption_feature]
return encoded_image, caption
def prefetch_input_data(reader,
file_pattern,
is_training,
batch_size,
values_per_shard,
input_queue_capacity_factor=16,
num_reader_threads=1,
shard_queue_name="filename_queue",
value_queue_name="input_queue"):
"""Prefetches string values from disk into an input queue.
In training, the capacity of the queue is important because a larger queue
means better mixing of training examples between shards. The minimum number of
values kept in the queue is values_per_shard * input_queue_capacity_factor,
where input_queue_capacity_factor should be chosen to trade off better mixing
against memory usage.
Args:
reader: Instance of tf.ReaderBase.
file_pattern: Comma-separated list of file patterns (e.g.
/tmp/train_data-?????-of-00100).
is_training: Boolean; whether prefetching for training or eval.
batch_size: Model batch size used to determine queue capacity.
values_per_shard: Approximate number of values per shard.
input_queue_capacity_factor: Minimum number of values to keep in the queue
in multiples of values_per_shard. See comments above.
num_reader_threads: Number of reader threads to fill the queue.
shard_queue_name: Name for the shards filename queue.
value_queue_name: Name for the values input queue.
Returns:
A Queue containing prefetched string values.
"""
data_files = []
for pattern in file_pattern.split(","):
data_files.extend(tf.gfile.Glob(pattern))
if not data_files:
tf.logging.fatal("Found no input files matching %s", file_pattern)
else:
tf.logging.info("Prefetching values from %d files matching %s",
len(data_files), file_pattern)
if is_training:
filename_queue = tf.train.string_input_producer(
data_files, shuffle=True, capacity=16, name=shard_queue_name)
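# Keep at least values_per_shard * input_queue_capacity_factor values queued so
# that examples from different shards are well mixed; for example, with
# values_per_shard=1000 and the default factor of 16, at least 16000 values stay
# in the queue. The extra 100 * batch_size provides headroom for downstream dequeues.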
min_queue_examples = values_per_shard * input_queue_capacity_factor
capacity = min_queue_examples + 100 * batch_size
values_queue = tf.RandomShuffleQueue(
capacity=capacity,
min_after_dequeue=min_queue_examples,
dtypes=[tf.string],
name="random_" + value_queue_name)
else:
filename_queue = tf.train.string_input_producer(
data_files, shuffle=False, capacity=1, name=shard_queue_name)
capacity = values_per_shard + 3 * batch_size
values_queue = tf.FIFOQueue(
capacity=capacity, dtypes=[tf.string], name="fifo_" + value_queue_name)
enqueue_ops = []
for _ in range(num_reader_threads):
_, value = reader.read(filename_queue)
enqueue_ops.append(values_queue.enqueue([value]))
tf.train.queue_runner.add_queue_runner(tf.train.queue_runner.QueueRunner(
values_queue, enqueue_ops))
tf.scalar_summary(
"queue/%s/fraction_of_%d_full" % (values_queue.name, capacity),
tf.cast(values_queue.size(), tf.float32) * (1. / capacity))
return values_queue
def batch_with_dynamic_pad(images_and_captions,
batch_size,
queue_capacity,
add_summaries=True):
"""Batches input images and captions.
This function splits the caption into an input sequence and a target sequence,
where the target sequence is the input sequence offset by one word (the target
at each position is the next input word). Input and target sequences are
batched and padded up to the maximum length of sequences in the batch. A mask
is created to distinguish real words from padding words.
Example:
Actual captions in the batch ('-' denotes padded character):
[
[ 1 2 3 4 5 ],
[ 1 2 3 4 - ],
[ 1 2 3 - - ],
]
input_seqs:
[
[ 1 2 3 4 ],
[ 1 2 3 - ],
[ 1 2 - - ],
]
target_seqs:
[
[ 2 3 4 5 ],
[ 2 3 4 - ],
[ 2 3 - - ],
]
mask:
[
[ 1 1 1 1 ],
[ 1 1 1 0 ],
[ 1 1 0 0 ],
]
Args:
images_and_captions: A list of pairs [image, caption], where image is a
Tensor of shape [height, width, channels] and caption is a 1-D Tensor of
any length. Each pair will be processed and added to the queue in a
separate thread.
batch_size: Batch size.
queue_capacity: Queue capacity.
add_summaries: If true, add caption length summaries.
Returns:
images: A Tensor of shape [batch_size, height, width, channels].
input_seqs: An int32 Tensor of shape [batch_size, padded_length].
target_seqs: An int32 Tensor of shape [batch_size, padded_length].
mask: An int32 0/1 Tensor of shape [batch_size, padded_length].
"""
enqueue_list = []
for image, caption in images_and_captions:
caption_length = tf.shape(caption)[0]
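# Both the input and target sequences have length caption_length - 1: the input
# drops the last word of the caption and the target drops the first word.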
input_length = tf.expand_dims(tf.sub(caption_length, 1), 0)
input_seq = tf.slice(caption, [0], input_length)
target_seq = tf.slice(caption, [1], input_length)
indicator = tf.ones(input_length, dtype=tf.int32)
enqueue_list.append([image, input_seq, target_seq, indicator])
images, input_seqs, target_seqs, mask = tf.train.batch_join(
enqueue_list,
batch_size=batch_size,
capacity=queue_capacity,
dynamic_pad=True,
name="batch_and_pad")
if add_summaries:
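# Each mask row sums to caption_length - 1, so add 1 back to recover the
# original caption lengths for the summaries.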
lengths = tf.add(tf.reduce_sum(mask, 1), 1)
tf.scalar_summary("caption_length/batch_min", tf.reduce_min(lengths))
tf.scalar_summary("caption_length/batch_max", tf.reduce_max(lengths))
tf.scalar_summary("caption_length/batch_mean", tf.reduce_mean(lengths))
return images, input_seqs, target_seqs, mask
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Generate captions for images using default beam search parameters."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import os
import tensorflow as tf
from im2txt import configuration
from im2txt import inference_wrapper
from im2txt.inference_utils import caption_generator
from im2txt.inference_utils import vocabulary
FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_string("checkpoint_path", "",
"Model checkpoint file or directory containing a "
"model checkpoint file.")
tf.flags.DEFINE_string("vocab_file", "", "Text file containing the vocabulary.")
tf.flags.DEFINE_string("input_files", "",
"File pattern or comma-separated list of file patterns "
"of image files.")
def main(_):
# Build the inference graph.
g = tf.Graph()
with g.as_default():
model = inference_wrapper.InferenceWrapper()
restore_fn = model.build_graph_from_config(configuration.ModelConfig(),
FLAGS.checkpoint_path)
g.finalize()
# Create the vocabulary.
vocab = vocabulary.Vocabulary(FLAGS.vocab_file)
filenames = []
for file_pattern in FLAGS.input_files.split(","):
filenames.extend(tf.gfile.Glob(file_pattern))
tf.logging.info("Running caption generation on %d files matching %s",
len(filenames), FLAGS.input_files)
with tf.Session(graph=g) as sess:
# Load the model from checkpoint.
restore_fn(sess)
# Prepare the caption generator. Here we are implicitly using the default
# beam search parameters. See caption_generator.py for a description of the
# available beam search parameters.
generator = caption_generator.CaptionGenerator(model, vocab)
for filename in filenames:
with tf.gfile.GFile(filename, "rb") as f:
image = f.read()
captions = generator.beam_search(sess, image)
print("Captions for image %s:" % os.path.basename(filename))
for i, caption in enumerate(captions):
# Ignore begin and end words.
sentence = [vocab.id_to_word(w) for w in caption.sentence[1:-1]]
sentence = " ".join(sentence)
print(" %d) %s (p=%f)" % (i, sentence, math.exp(caption.logprob)))
if __name__ == "__main__":
tf.app.run()
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Image-to-text implementation based on http://arxiv.org/abs/1411.4555.
"Show and Tell: A Neural Image Caption Generator"
Oriol Vinyals, Alexander Toshev, Samy Bengio, Dumitru Erhan
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from im2txt.ops import image_embedding
from im2txt.ops import image_processing
from im2txt.ops import inputs as input_ops
class ShowAndTellModel(object):
"""Image-to-text implementation based on http://arxiv.org/abs/1411.4555.
"Show and Tell: A Neural Image Caption Generator"
Oriol Vinyals, Alexander Toshev, Samy Bengio, Dumitru Erhan
"""
def __init__(self, config, mode, train_inception=False):
"""Basic setup.
Args:
config: Object containing configuration parameters.
mode: "train", "eval" or "inference".
train_inception: Whether the inception submodel variables are trainable.
"""
assert mode in ["train", "eval", "inference"]
self.config = config
self.mode = mode
self.train_inception = train_inception
# Reader for the input data.
self.reader = tf.TFRecordReader()
# To match the "Show and Tell" paper we initialize all variables with a
# random uniform initializer.
self.initializer = tf.random_uniform_initializer(
minval=-self.config.initializer_scale,
maxval=self.config.initializer_scale)
# A float32 Tensor with shape [batch_size, height, width, channels].
self.images = None
# An int32 Tensor with shape [batch_size, padded_length].
self.input_seqs = None
# An int32 Tensor with shape [batch_size, padded_length].
self.target_seqs = None
# An int32 0/1 Tensor with shape [batch_size, padded_length].
self.input_mask = None
# A float32 Tensor with shape [batch_size, embedding_size].
self.image_embeddings = None
# A float32 Tensor with shape [batch_size, padded_length, embedding_size].
self.seq_embeddings = None
# A float32 scalar Tensor; the total loss for the trainer to optimize.
self.total_loss = None
# A float32 Tensor with shape [batch_size * padded_length].
self.target_cross_entropy_losses = None
# A float32 Tensor with shape [batch_size * padded_length].
self.target_cross_entropy_loss_weights = None
# Collection of variables from the inception submodel.
self.inception_variables = []
# Function to restore the inception submodel from checkpoint.
self.init_fn = None
# Global step Tensor.
self.global_step = None
def is_training(self):
"""Returns true if the model is built for training mode."""
return self.mode == "train"
def process_image(self, encoded_image, thread_id=0):
"""Decodes and processes an image string.
Args:
encoded_image: A scalar string Tensor; the encoded image.
thread_id: Preprocessing thread id used to select the ordering of color
distortions.
Returns:
A float32 Tensor of shape [height, width, 3]; the processed image.
"""
return image_processing.process_image(encoded_image,
is_training=self.is_training(),
height=self.config.image_height,
width=self.config.image_width,
thread_id=thread_id,
image_format=self.config.image_format)
def build_inputs(self):
"""Input prefetching, preprocessing and batching.
Outputs:
self.images
self.input_seqs
self.target_seqs (training and eval only)
self.input_mask (training and eval only)
"""
if self.mode == "inference":
# In inference mode, images and inputs are fed via placeholders.
image_feed = tf.placeholder(dtype=tf.string, shape=[], name="image_feed")
input_feed = tf.placeholder(dtype=tf.int64,
shape=[None], # batch_size
name="input_feed")
# Process image and insert batch dimensions.
images = tf.expand_dims(self.process_image(image_feed), 0)
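# Each fed word id becomes a sequence of length 1, so the LSTM advances one
# step per Session.run() call during inference.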
input_seqs = tf.expand_dims(input_feed, 1)
# No target sequences or input mask in inference mode.
target_seqs = None
input_mask = None
else:
# Prefetch serialized SequenceExample protos.
input_queue = input_ops.prefetch_input_data(
self.reader,
self.config.input_file_pattern,
is_training=self.is_training(),
batch_size=self.config.batch_size,
values_per_shard=self.config.values_per_input_shard,
input_queue_capacity_factor=self.config.input_queue_capacity_factor,
num_reader_threads=self.config.num_input_reader_threads)
# Image processing and random distortion. Split across multiple threads
# with each thread applying a slightly different distortion.
assert self.config.num_preprocess_threads % 2 == 0
images_and_captions = []
for thread_id in range(self.config.num_preprocess_threads):
serialized_sequence_example = input_queue.dequeue()
encoded_image, caption = input_ops.parse_sequence_example(
serialized_sequence_example,
image_feature=self.config.image_feature_name,
caption_feature=self.config.caption_feature_name)
image = self.process_image(encoded_image, thread_id=thread_id)
images_and_captions.append([image, caption])
# Batch inputs.
queue_capacity = (2 * self.config.num_preprocess_threads *
self.config.batch_size)
images, input_seqs, target_seqs, input_mask = (
input_ops.batch_with_dynamic_pad(images_and_captions,
batch_size=self.config.batch_size,
queue_capacity=queue_capacity))
self.images = images
self.input_seqs = input_seqs
self.target_seqs = target_seqs
self.input_mask = input_mask
def build_image_embeddings(self):
"""Builds the image model subgraph and generates image embeddings.
Inputs:
self.images
Outputs:
self.image_embeddings
"""
inception_output = image_embedding.inception_v3(
self.images,
trainable=self.train_inception,
is_training=self.is_training())
self.inception_variables = tf.get_collection(
tf.GraphKeys.VARIABLES, scope="InceptionV3")
# Map inception output into embedding space.
with tf.variable_scope("image_embedding") as scope:
image_embeddings = tf.contrib.layers.fully_connected(
inputs=inception_output,
num_outputs=self.config.embedding_size,
activation_fn=None,
weights_initializer=self.initializer,
biases_initializer=None,
scope=scope)
# Save the embedding size in the graph.
tf.constant(self.config.embedding_size, name="embedding_size")
self.image_embeddings = image_embeddings
def build_seq_embeddings(self):
"""Builds the input sequence embeddings.
Inputs:
self.input_seqs
Outputs:
self.seq_embeddings
"""
with tf.variable_scope("seq_embedding"), tf.device("/cpu:0"):
embedding_map = tf.get_variable(
name="map",
shape=[self.config.vocab_size, self.config.embedding_size],
initializer=self.initializer)
seq_embeddings = tf.nn.embedding_lookup(embedding_map, self.input_seqs)
self.seq_embeddings = seq_embeddings
def build_model(self):
"""Builds the model.
Inputs:
self.image_embeddings
self.seq_embeddings
self.target_seqs (training and eval only)
self.input_mask (training and eval only)
Outputs:
self.total_loss (training and eval only)
self.target_cross_entropy_losses (training and eval only)
self.target_cross_entropy_loss_weights (training and eval only)
"""
# This LSTM cell has biases and outputs tanh(new_c) * sigmoid(o), but the
# modified LSTM in the "Show and Tell" paper has no biases and outputs
# new_c * sigmoid(o).
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(
num_units=self.config.num_lstm_units, state_is_tuple=True)
if self.mode == "train":
lstm_cell = tf.nn.rnn_cell.DropoutWrapper(
lstm_cell,
input_keep_prob=self.config.lstm_dropout_keep_prob,
output_keep_prob=self.config.lstm_dropout_keep_prob)
with tf.variable_scope("lstm", initializer=self.initializer) as lstm_scope:
# Feed the image embeddings to set the initial LSTM state.
zero_state = lstm_cell.zero_state(
batch_size=self.image_embeddings.get_shape()[0], dtype=tf.float32)
_, initial_state = lstm_cell(self.image_embeddings, zero_state)
# Allow the LSTM variables to be reused.
lstm_scope.reuse_variables()
if self.mode == "inference":
# In inference mode, use concatenated states for convenient feeding and
# fetching.
tf.concat(1, initial_state, name="initial_state")
# Placeholder for feeding a batch of concatenated states.
state_feed = tf.placeholder(dtype=tf.float32,
shape=[None, sum(lstm_cell.state_size)],
name="state_feed")
state_tuple = tf.split(1, 2, state_feed)
# Run a single LSTM step.
lstm_outputs, state_tuple = lstm_cell(
inputs=tf.squeeze(self.seq_embeddings, squeeze_dims=[1]),
state=state_tuple)
# Concatenate the resulting state.
tf.concat(1, state_tuple, name="state")
else:
# Run the batch of sequence embeddings through the LSTM.
sequence_length = tf.reduce_sum(self.input_mask, 1)
lstm_outputs, _ = tf.nn.dynamic_rnn(cell=lstm_cell,
inputs=self.seq_embeddings,
sequence_length=sequence_length,
initial_state=initial_state,
dtype=tf.float32,
scope=lstm_scope)
# Stack batches vertically.
lstm_outputs = tf.reshape(lstm_outputs, [-1, lstm_cell.output_size])
with tf.variable_scope("logits") as logits_scope:
logits = tf.contrib.layers.fully_connected(
inputs=lstm_outputs,
num_outputs=self.config.vocab_size,
activation_fn=None,
weights_initializer=self.initializer,
scope=logits_scope)
if self.mode == "inference":
tf.nn.softmax(logits, name="softmax")
else:
targets = tf.reshape(self.target_seqs, [-1])
weights = tf.to_float(tf.reshape(self.input_mask, [-1]))
# Compute losses.
losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, targets)
batch_loss = tf.div(tf.reduce_sum(tf.mul(losses, weights)),
tf.reduce_sum(weights),
name="batch_loss")
tf.contrib.losses.add_loss(batch_loss)
total_loss = tf.contrib.losses.get_total_loss()
# Add summaries.
tf.scalar_summary("batch_loss", batch_loss)
tf.scalar_summary("total_loss", total_loss)
for var in tf.trainable_variables():
tf.histogram_summary(var.op.name, var)
self.total_loss = total_loss
self.target_cross_entropy_losses = losses # Used in evaluation.
self.target_cross_entropy_loss_weights = weights # Used in evaluation.
def setup_inception_initializer(self):
"""Sets up the function to restore inception variables from checkpoint."""
if self.mode != "inference":
# Restore inception variables only.
saver = tf.train.Saver(self.inception_variables)
def restore_fn(sess):
tf.logging.info("Restoring Inception variables from checkpoint file %s",
self.config.inception_checkpoint_file)
saver.restore(sess, self.config.inception_checkpoint_file)
self.init_fn = restore_fn
def setup_global_step(self):
"""Sets up the global step Tensor."""
global_step = tf.Variable(
initial_value=0,
name="global_step",
trainable=False,
collections=[tf.GraphKeys.GLOBAL_STEP, tf.GraphKeys.VARIABLES])
self.global_step = global_step
def setup_saver(self):
"""Sets up the Saver for loading and saving model checkpoints."""
self.saver = tf.train.Saver(
max_to_keep=self.config.max_checkpoints_to_keep,
keep_checkpoint_every_n_hours=self.config.keep_checkpoint_every_n_hours)
def build(self):
"""Creates all ops for training and evaluation."""
self.build_inputs()
self.build_image_embeddings()
self.build_seq_embeddings()
self.build_model()
self.setup_inception_initializer()
self.setup_global_step()
self.setup_saver()
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow_models.im2txt.show_and_tell_model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from im2txt import configuration
from im2txt import show_and_tell_model
class ShowAndTellModel(show_and_tell_model.ShowAndTellModel):
"""Subclass of ShowAndTellModel without the disk I/O."""
def build_inputs(self):
if self.mode == "inference":
# Inference mode doesn't read from disk, so defer to parent.
return super(ShowAndTellModel, self).build_inputs()
else:
# Replace disk I/O with random Tensors.
self.images = tf.random_uniform(
shape=[self.config.batch_size, self.config.image_height,
self.config.image_width, 3],
minval=-1,
maxval=1)
self.input_seqs = tf.random_uniform(
[self.config.batch_size, 15],
minval=0,
maxval=self.config.vocab_size,
dtype=tf.int64)
self.target_seqs = tf.random_uniform(
[self.config.batch_size, 15],
minval=0,
maxval=self.config.vocab_size,
dtype=tf.int64)
self.input_mask = tf.ones_like(self.input_seqs)
class ShowAndTellModelTest(tf.test.TestCase):
def setUp(self):
super(ShowAndTellModelTest, self).setUp()
self._model_config = configuration.ModelConfig()
def _countModelParameters(self):
"""Counts the number of parameters in the model at top level scope."""
counter = {}
for v in tf.all_variables():
name = v.op.name.split("/")[0]
num_params = v.get_shape().num_elements()
assert num_params
counter[name] = counter.get(name, 0) + num_params
return counter
def _checkModelParameters(self):
"""Verifies the number of parameters in the model."""
param_counts = self._countModelParameters()
expected_param_counts = {
"InceptionV3": 21802784,
# inception_output_size * embedding_size
"image_embedding": 1048576,
# vocab_size * embedding_size
"seq_embedding": 6144000,
# (embedding_size + num_lstm_units + 1) * 4 * num_lstm_units
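# = (512 + 512 + 1) * 4 * 512 = 2099200 with the default 512-unit configuration.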
"lstm": 2099200,
# (num_lstm_units + 1) * vocab_size
"logits": 6156000,
"global_step": 1,
}
self.assertDictEqual(expected_param_counts, param_counts)
def _checkOutputs(self, expected_shapes, feed_dict=None):
"""Verifies that the model produces expected outputs.
Args:
expected_shapes: A dict mapping Tensor or Tensor name to expected output
shape.
feed_dict: Values of Tensors to feed into Session.run().
"""
fetches = expected_shapes.keys()
with self.test_session() as sess:
sess.run(tf.initialize_all_variables())
outputs = sess.run(fetches, feed_dict)
for index, output in enumerate(outputs):
tensor = fetches[index]
expected = expected_shapes[tensor]
actual = output.shape
if expected != actual:
self.fail("Tensor %s has shape %s (expected %s)." %
(tensor, actual, expected))
def testBuildForTraining(self):
model = ShowAndTellModel(self._model_config, mode="train")
model.build()
self._checkModelParameters()
expected_shapes = {
# [batch_size, image_height, image_width, 3]
model.images: (32, 299, 299, 3),
# [batch_size, sequence_length]
model.input_seqs: (32, 15),
# [batch_size, sequence_length]
model.target_seqs: (32, 15),
# [batch_size, sequence_length]
model.input_mask: (32, 15),
# [batch_size, embedding_size]
model.image_embeddings: (32, 512),
# [batch_size, sequence_length, embedding_size]
model.seq_embeddings: (32, 15, 512),
# Scalar
model.total_loss: (),
# [batch_size * sequence_length]
model.target_cross_entropy_losses: (480,),
# [batch_size * sequence_length]
model.target_cross_entropy_loss_weights: (480,),
}
self._checkOutputs(expected_shapes)
def testBuildForEval(self):
model = ShowAndTellModel(self._model_config, mode="eval")
model.build()
self._checkModelParameters()
expected_shapes = {
# [batch_size, image_height, image_width, 3]
model.images: (32, 299, 299, 3),
# [batch_size, sequence_length]
model.input_seqs: (32, 15),
# [batch_size, sequence_length]
model.target_seqs: (32, 15),
# [batch_size, sequence_length]
model.input_mask: (32, 15),
# [batch_size, embedding_size]
model.image_embeddings: (32, 512),
# [batch_size, sequence_length, embedding_size]
model.seq_embeddings: (32, 15, 512),
# Scalar
model.total_loss: (),
# [batch_size * sequence_length]
model.target_cross_entropy_losses: (480,),
# [batch_size * sequence_length]
model.target_cross_entropy_loss_weights: (480,),
}
self._checkOutputs(expected_shapes)
def testBuildForInference(self):
model = ShowAndTellModel(self._model_config, mode="inference")
model.build()
self._checkModelParameters()
# Test feeding an image to get the initial LSTM state.
images_feed = np.random.rand(1, 299, 299, 3)
feed_dict = {model.images: images_feed}
expected_shapes = {
# [batch_size, embedding_size]
model.image_embeddings: (1, 512),
# [batch_size, 2 * num_lstm_units]
"lstm/initial_state:0": (1, 1024),
}
self._checkOutputs(expected_shapes, feed_dict)
# Test feeding a batch of inputs and LSTM states to get softmax output and
# LSTM states.
input_feed = np.random.randint(0, 10, size=3)
state_feed = np.random.rand(3, 1024)
feed_dict = {"input_feed:0": input_feed, "lstm/state_feed:0": state_feed}
expected_shapes = {
# [batch_size, 2 * num_lstm_units]
"lstm/state:0": (3, 1024),
# [batch_size, vocab_size]
"softmax:0": (3, 12000),
}
self._checkOutputs(expected_shapes, feed_dict)
if __name__ == "__main__":
tf.test.main()
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Train the model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from im2txt import configuration
from im2txt import show_and_tell_model
FLAGS = tf.app.flags.FLAGS
tf.flags.DEFINE_string("input_file_pattern", "",
"File pattern of sharded TFRecord input files.")
tf.flags.DEFINE_string("inception_checkpoint_file", "",
"Path to a pretrained inception_v3 model.")
tf.flags.DEFINE_string("train_dir", "",
"Directory for saving and loading model checkpoints.")
tf.flags.DEFINE_boolean("train_inception", False,
"Whether to train inception submodel variables.")
tf.flags.DEFINE_integer("number_of_steps", 1000000, "Number of training steps.")
tf.flags.DEFINE_integer("log_every_n_steps", 1,
"Frequency at which loss and global step are logged.")
tf.logging.set_verbosity(tf.logging.INFO)
def main(unused_argv):
assert FLAGS.input_file_pattern, "--input_file_pattern is required"
assert FLAGS.train_dir, "--train_dir is required"
model_config = configuration.ModelConfig()
model_config.input_file_pattern = FLAGS.input_file_pattern
model_config.inception_checkpoint_file = FLAGS.inception_checkpoint_file
training_config = configuration.TrainingConfig()
# Create training directory.
train_dir = FLAGS.train_dir
if not tf.gfile.IsDirectory(train_dir):
tf.logging.info("Creating training directory: %s", train_dir)
tf.gfile.MakeDirs(train_dir)
# Build the TensorFlow graph.
g = tf.Graph()
with g.as_default():
# Build the model.
model = show_and_tell_model.ShowAndTellModel(
model_config, mode="train", train_inception=FLAGS.train_inception)
model.build()
# Set up the learning rate.
learning_rate_decay_fn = None
if FLAGS.train_inception:
learning_rate = tf.constant(training_config.train_inception_learning_rate)
else:
learning_rate = tf.constant(training_config.initial_learning_rate)
if training_config.learning_rate_decay_factor > 0:
num_batches_per_epoch = (training_config.num_examples_per_epoch /
model_config.batch_size)
decay_steps = int(num_batches_per_epoch *
training_config.num_epochs_per_decay)
def _learning_rate_decay_fn(learning_rate, global_step):
return tf.train.exponential_decay(
learning_rate,
global_step,
decay_steps=decay_steps,
decay_rate=training_config.learning_rate_decay_factor,
staircase=True)
learning_rate_decay_fn = _learning_rate_decay_fn
# Set up the training ops.
train_op = tf.contrib.layers.optimize_loss(
loss=model.total_loss,
global_step=model.global_step,
learning_rate=learning_rate,
optimizer=training_config.optimizer,
clip_gradients=training_config.clip_gradients,
learning_rate_decay_fn=learning_rate_decay_fn)
# Run training.
tf.contrib.slim.learning.train(
train_op,
train_dir,
log_every_n_steps=FLAGS.log_every_n_steps,
graph=g,
global_step=model.global_step,
number_of_steps=FLAGS.number_of_steps,
init_fn=model.init_fn,
saver=model.saver)
if __name__ == "__main__":
tf.app.run()