Commit 4f9d1024 authored by Chris Shallue's avatar Chris Shallue

Open source the image-to-text model based on the "Show and Tell" paper.

parent 54886315
@@ -19,3 +19,4 @@ To propose a model for inclusion please submit a pull request.
- [syntaxnet](syntaxnet) -- neural models of natural language syntax
- [textsum](textsum) -- sequence-to-sequence with attention model for text summarization.
- [transformer](transformer) -- spatial transformer network, which allows the spatial manipulation of data within the network
- [im2txt](im2txt) -- image-to-text neural network for image captioning.
/bazel-bin
/bazel-ci_build-cache
/bazel-genfiles
/bazel-out
/bazel-im2txt
/bazel-testlogs
/bazel-tf
# Show and Tell: A Neural Image Caption Generator
A TensorFlow implementation of the image-to-text model described in
*Show and Tell: A Neural Image Caption Generator*
Oriol Vinyals, Alexander Toshev, Samy Bengio, Dumitru Erhan.
http://arxiv.org/abs/1411.4555
## Contact
***Author:*** Chris Shallue (shallue@google.com).
***Pull requests and issues:*** @cshallue.
## Contents
* [Model Overview](#model-overview)
* [Introduction](#introduction)
* [Architecture](#architecture)
* [Getting Started](#getting-started)
* [A Note on Hardware and Training Time](#a-note-on-hardware-and-training-time)
* [Install Required Packages](#install-required-packages)
* [Prepare the Training Data](#prepare-the-training-data)
* [Download the Inception v3 Checkpoint](#download-the-inception-v3-checkpoint)
* [Training a Model](#training-a-model)
* [Initial Training](#initial-training)
* [Fine Tune the Inception v3 Model](#fine-tune-the-inception-v3-model)
* [Generating Captions](#generating-captions)
## Model Overview
### Introduction
The *Show and Tell* model is a deep neural network that learns how to describe
the content of images. For example:
<center>
![Example captions](g3doc/example_captions.jpg)
</center>
### Architecture
The *Show and Tell* model is an example of an *encoder-decoder* neural network.
It works by first "encoding" an image into a fixed-length vector representation,
and then "decoding" the representation into a natural language description.
The image encoder is a deep convolutional neural network. This type of
network is widely used for image tasks and is currently state-of-the-art for
object recognition and detection. Our particular choice of network is the
[*Inception v3*](http://arxiv.org/abs/1512.00567) image recognition model
pretrained on the
[ILSVRC-2012-CLS](http://www.image-net.org/challenges/LSVRC/2012/) image
classification dataset.
The decoder is a long short-term memory (LSTM) network. This type of network is
commonly used for sequence modeling tasks such as language modeling and machine
translation. In the *Show and Tell* model, the LSTM network is trained as a
language model conditioned on the image encoding.
Words in the captions are represented with an embedding model. Each word in the
vocabulary is associated with a fixed-length vector representation that is
learned during training.
The following diagram illustrates the model architecture.
<center>
![Show and Tell Architecture](g3doc/show_and_tell_architecture.png)
</center>
In this diagram, $$\{ s_0, s_1, ..., s_{N-1} \}$$ are the words of the caption
and $$\{ w_e s_0, w_e s_1, ..., w_e s_{N-1} \}$$ are their corresponding word
embedding vectors. The outputs $$\{ p_1, p_2, ..., p_N \}$$ of the LSTM are
probability distributions generated by the model for the next word in the
sentence. The terms $$\{ \log p_1(s_1), \log p_2(s_2), ..., \log p_N(s_N) \}$$
are the log-likelihoods of the correct word at each step; the negated sum of
these terms is the minimization objective of the model.
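To make the objective concrete, here is a minimal NumPy sketch of the quantity being minimized, using hypothetical softmax outputs and caption ids (illustration only, not the model code):

```python
import numpy as np

# Hypothetical values: p_1, p_2, p_3 are softmax outputs over a toy vocabulary,
# and caption_ids holds the word ids s_0, s_1, s_2, s_3 of one caption.
vocab_size = 10
softmax_outputs = [np.full(vocab_size, 1.0 / vocab_size) for _ in range(3)]
caption_ids = [0, 4, 7, 1]

# Sum of log-likelihoods of the correct next word at each step; the training
# objective is the negation of this sum.
log_likelihood = sum(
    np.log(p[s]) for p, s in zip(softmax_outputs, caption_ids[1:]))
loss = -log_likelihood
```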
During the first phase of training the parameters of the *Inception v3* model
are kept fixed: it is simply a static image encoder function. A single trainable
layer is added on top of the *Inception v3* model to transform the image
embedding into the word embedding vector space. The model is trained with
respect to the parameters of the word embeddings, the parameters of the layer on
top of *Inception v3* and the parameters of the LSTM. In the second phase of
training, all parameters - including the parameters of *Inception v3* - are
trained to jointly fine-tune the image encoder and the LSTM.
Given a trained model and an image we use *beam search* to generate captions for
that image. Captions are generated word-by-word, where at each step $$t$$ we use
the set of sentences already generated with length $$t-1$$ to generate a new set
of sentences with length $$t$$. We keep only the top $$k$$ candidates at each
step, where the hyperparameter $$k$$ is called the *beam size*. We have found
the best performance with $$k=3$$.
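As an illustration, the following toy sketch follows the beam search procedure just described. The real implementation (which additionally supports length normalization and batches the LSTM inference step across partial captions) is in `im2txt/inference_utils/caption_generator.py` below; `next_word_probs` here is a hypothetical stand-in for the decoder.

```python
import math

def beam_search(next_word_probs, start_id, end_id, beam_size=3, max_len=20):
  """Toy beam search; next_word_probs(sentence) returns {word_id: probability}."""
  partial = [([start_id], 0.0)]  # (sentence, log probability)
  complete = []
  for _ in range(max_len - 1):
    candidates = []
    for sentence, logprob in partial:
      for word, prob in next_word_probs(sentence).items():
        candidates.append((sentence + [word], logprob + math.log(prob)))
    # Keep only the top `beam_size` candidates at each step.
    candidates.sort(key=lambda c: -c[1])
    complete += [c for c in candidates[:beam_size] if c[0][-1] == end_id]
    partial = [c for c in candidates[:beam_size] if c[0][-1] != end_id]
    if not partial:
      break
  return sorted(complete or partial, key=lambda c: -c[1])
```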
## Getting Started
### A Note on Hardware and Training Time
The time required to train the *Show and Tell* model depends on your specific
hardware and computational capacity. In this guide we assume you will be running
training on a single machine with a GPU. In our experience on an NVIDIA Tesla
K20m GPU the initial training phase takes 1-2 weeks. The second training phase
may take several additional weeks to achieve peak performance (but you can stop
this phase early and still get reasonable results).
It is possible to achieve a speed-up by implementing distributed training across
a cluster of machines with GPUs, but that is not covered in this guide.
While it is possible to run this code on a CPU, be aware that it may be
approximately 10 times slower than on a GPU.
### Install Required Packages
First ensure that you have installed the following required packages:
* **Bazel** ([instructions](http://bazel.io/docs/install.html)).
* **TensorFlow** ([instructions](https://www.tensorflow.org/versions/r0.10/get_started/os_setup.html)).
* **NumPy** ([instructions](http://www.scipy.org/install.html)).
* **Natural Language Toolkit (NLTK)**:
* First install NLTK ([instructions](http://www.nltk.org/install.html)).
* Then install the NLTK data ([instructions](http://www.nltk.org/data.html)); a snippet for fetching the required tokenizer data follows this list.
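The NLTK word tokenizer used by the preprocessing script depends on the `punkt` tokenizer data. One way to fetch it (illustrative snippet, not part of the repository):

```python
# Illustrative only: fetch the tokenizer data required by
# nltk.tokenize.word_tokenize, which the preprocessing script uses.
import nltk
nltk.download("punkt")
```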
### Prepare the Training Data
To train the model you will need to provide training data in native TFRecord
format. The TFRecord format consists of a set of sharded files containing
serialized `tf.SequenceExample` protocol buffers. Each `tf.SequenceExample`
proto contains an image (JPEG format), a caption and metadata such as the image
id.
Each caption is a list of words. During preprocessing, a dictionary is created
that assigns each word in the vocabulary to an integer-valued id. Each caption
is encoded as a list of integer word ids in the `tf.SequenceExample` protos.
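For reference, here is a minimal sketch of how one image-caption pair could be serialized into a `tf.SequenceExample` with the feature names used by this model. The image bytes and ids below are placeholder values; the repository's actual preprocessing code is in `im2txt/data/build_mscoco_data.py` (included further down).

```python
import tensorflow as tf

# Placeholder values for illustration.
image_id = 42
encoded_jpeg = b"\xff\xd8..."  # JPEG bytes as read from disk
caption_words = ["<S>", "a", "surfer", "</S>"]
caption_ids = [1, 5, 978, 2]   # integer ids assigned by the vocabulary

example = tf.train.SequenceExample(
    context=tf.train.Features(feature={
        "image/image_id": tf.train.Feature(
            int64_list=tf.train.Int64List(value=[image_id])),
        "image/data": tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[encoded_jpeg])),
    }),
    feature_lists=tf.train.FeatureLists(feature_list={
        "image/caption": tf.train.FeatureList(feature=[
            tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[w.encode("utf-8")]))
            for w in caption_words]),
        "image/caption_ids": tf.train.FeatureList(feature=[
            tf.train.Feature(int64_list=tf.train.Int64List(value=[i]))
            for i in caption_ids]),
    }))
serialized = example.SerializeToString()
```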
We have provided a script to download and preprocess the
[MSCOCO](http://mscoco.org/) image captioning data set into this format.
Downloading and preprocessing the data may take several hours depending on your
network and computer speed. Please be patient.
Before running the script, ensure that your hard disk has at least 150GB of
available space for storing the downloaded and processed data.
```shell
# Location to save the MSCOCO data.
MSCOCO_DIR="${HOME}/im2txt/data/mscoco"
# Build the preprocessing script.
bazel build im2txt/download_and_preprocess_mscoco
# Run the preprocessing script.
bazel-bin/im2txt/download_and_preprocess_mscoco "${MSCOCO_DIR}"
```
The final line of the output should read:
```
2016-09-01 16:47:47.296630: Finished processing all 20267 image-caption pairs in data set 'test'.
```
When the script finishes you will find 256 training, 4 validation and 8 testing
files in `MSCOCO_DIR`. The files will match the patterns `train-?????-of-00256`,
`val-?????-of-00004` and `test-?????-of-00008`, respectively.
### Download the Inception v3 Checkpoint
The *Show and Tell* model requires a pretrained *Inception v3* checkpoint file
to initialize the parameters of its image encoder submodel.
This checkpoint file is provided by the
[TensorFlow-Slim image classification library](https://github.com/tensorflow/models/tree/master/slim#tensorflow-slim-image-classification-library)
which provides a suite of pre-trained image classification models. You can read
more about the models provided by the library
[here](https://github.com/tensorflow/models/tree/master/slim#pre-trained-models).
Run the following commands to download the *Inception v3* checkpoint.
```shell
# Location to save the Inception v3 checkpoint.
INCEPTION_DIR="${HOME}/im2txt/data"
mkdir -p ${INCEPTION_DIR}
wget "http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz"
tar -xvf "inception_v3_2016_08_28.tar.gz" -C ${INCEPTION_DIR}
rm "inception_v3_2016_08_28.tar.gz"
```
Note that the *Inception v3* checkpoint will only be used for initializing the
parameters of the *Show and Tell* model. Once the *Show and Tell* model starts
training it will save its own checkpoint files containing the values of all its
parameters (including copies of the *Inception v3* parameters). If training is
stopped and restarted, the parameter values will be restored from the latest
*Show and Tell* checkpoint and the *Inception v3* checkpoint will be ignored. In
other words, the *Inception v3* checkpoint is only used in the 0-th global step
(initialization) of training the *Show and Tell* model.
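A rough sketch of that initialization logic is shown below. This is illustration only (the actual handling lives in `im2txt/train.py` and `show_and_tell_model.py`), and the `InceptionV3/` variable scope prefix is an assumption here.

```python
import tensorflow as tf

def make_init_fn(train_dir, inception_checkpoint_file):
  """Illustrative only: pick which checkpoint initializes the model."""
  if tf.train.latest_checkpoint(train_dir):
    # A Show and Tell checkpoint already exists; all parameters will be
    # restored from it and the Inception v3 checkpoint is ignored.
    return None

  # First run (global step 0): restore only the Inception v3 variables from
  # the downloaded checkpoint; everything else keeps its random initialization.
  inception_variables = [v for v in tf.all_variables()
                         if v.name.startswith("InceptionV3/")]  # assumed scope
  saver = tf.train.Saver(inception_variables)

  def init_fn(sess):
    saver.restore(sess, inception_checkpoint_file)

  return init_fn
```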
## Training a Model
### Initial Training
Run the training script.
```shell
# Directory containing preprocessed MSCOCO data.
MSCOCO_DIR="${HOME}/im2txt/data/mscoco"
# Inception v3 checkpoint file.
INCEPTION_CHECKPOINT="${HOME}/im2txt/data/inception_v3.ckpt"
# Directory to save the model.
MODEL_DIR="${HOME}/im2txt/model"
# Build the model.
bazel build -c opt --config=cuda im2txt/...
# Run the training script.
bazel-bin/im2txt/train \
--input_file_pattern="${MSCOCO_DIR}/train-?????-of-00256" \
--inception_checkpoint_file="${INCEPTION_CHECKPOINT}" \
--train_dir="${MODEL_DIR}/train" \
--train_inception=false \
--number_of_steps=1000000
```
Run the evaluation script in a separate process. This will log evaluation
metrics to TensorBoard which allows training progress to be monitored in
real-time.
Note that you may run out of memory if you run the evaluation script on the same
GPU as the training script. You can run the command
`export CUDA_VISIBLE_DEVICES=""` to force the evaluation script to run on CPU.
If evaluation runs too slowly on CPU, you can decrease the value of
`--num_eval_examples`.
```shell
MSCOCO_DIR="${HOME}/im2txt/data/mscoco"
MODEL_DIR="${HOME}/im2txt/model"
# Ignore GPU devices (only necessary if your GPU is currently memory
# constrained, for example, by running the training script).
export CUDA_VISIBLE_DEVICES=""
# Run the evaluation script. This will run in a loop, periodically loading the
# latest model checkpoint file and computing evaluation metrics.
bazel-bin/im2txt/evaluate \
--input_file_pattern="${MSCOCO_DIR}/val-?????-of-00004" \
--checkpoint_dir="${MODEL_DIR}/train" \
--eval_dir="${MODEL_DIR}/eval"
```
Run a TensorBoard server in a separate process for real-time monitoring of
training progress and evaluation metrics.
```shell
MODEL_DIR="${HOME}/im2txt/model"
# Run a TensorBoard server.
tensorboard --logdir="${MODEL_DIR}"
```
### Fine Tune the Inception v3 Model
Your model will already be able to generate reasonable captions after the first
phase of training. Try it out! (See
[Generating Captions](#generating-captions)).
You can further improve the performance of the model by running a
second training phase to jointly fine-tune the parameters of the *Inception v3*
image submodel and the LSTM.
```shell
# Restart the training script with --train_inception=true.
bazel-bin/im2txt/train \
--input_file_pattern="${MSCOCO_DIR}/train-?????-of-00256" \
--train_dir="${MODEL_DIR}/train" \
--train_inception=true \
--number_of_steps=3000000 # Additional 2M steps (assuming 1M in initial training).
```
Note that training will proceed much slower now, and the model will continue to
improve by a small amount for a long time. We have found that it will improve
slowly for an additional 2-2.5 million steps before it begins to overfit. This
may take several weeks on a single GPU. If you don't care about absolutely
optimal performance then feel free to halt training sooner by stopping the
training script or passing a smaller value to the flag `--number_of_steps`. Your
model will still work reasonably well.
## Generating Captions
Your trained *Show and Tell* model can generate captions for any JPEG image! The
following command line will generate captions for an image from the test set.
```shell
# Directory containing model checkpoints.
CHECKPOINT_DIR="${HOME}/im2txt/model/train"
# Vocabulary file generated by the preprocessing script.
VOCAB_FILE="${HOME}/im2txt/data/mscoco/word_counts.txt"
# JPEG image file to caption.
IMAGE_FILE="${HOME}/im2txt/data/mscoco/raw-data/val2014/COCO_val2014_000000224477.jpg"
# Build the inference binary.
bazel build -c opt im2txt/run_inference
# Ignore GPU devices (only necessary if your GPU is currently memory
# constrained, for example, by running the training script).
export CUDA_VISIBLE_DEVICES=""
# Run inference to generate captions.
bazel-bin/im2txt/run_inference \
--checkpoint_path=${CHECKPOINT_DIR} \
--vocab_file=${VOCAB_FILE} \
--input_files=${IMAGE_FILE}
```
Example output:
```shell
Captions for image COCO_val2014_000000224477.jpg:
0) a man riding a wave on top of a surfboard . (p=0.040413)
1) a person riding a surf board on a wave (p=0.017452)
2) a man riding a wave on a surfboard in the ocean . (p=0.005743)
```
Note: you may get different results. Some variation between different models is
expected.
Here is the image:
<center>
![Surfer](g3doc/COCO_val2014_000000224477.jpg)
</center>
package(default_visibility = [":internal"])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
package_group(
name = "internal",
packages = [
"//im2txt/...",
],
)
py_binary(
name = "build_mscoco_data",
srcs = [
"data/build_mscoco_data.py",
],
)
sh_binary(
name = "download_and_preprocess_mscoco",
srcs = ["data/download_and_preprocess_mscoco.sh"],
data = [
":build_mscoco_data",
],
)
py_library(
name = "configuration",
srcs = ["configuration.py"],
srcs_version = "PY2AND3",
)
py_library(
name = "show_and_tell_model",
srcs = ["show_and_tell_model.py"],
srcs_version = "PY2AND3",
deps = [
"//im2txt/ops:image_embedding",
"//im2txt/ops:image_processing",
"//im2txt/ops:inputs",
],
)
py_test(
name = "show_and_tell_model_test",
size = "large",
srcs = ["show_and_tell_model_test.py"],
deps = [
":configuration",
":show_and_tell_model",
],
)
py_library(
name = "inference_wrapper",
srcs = ["inference_wrapper.py"],
srcs_version = "PY2AND3",
deps = [
":show_and_tell_model",
"//im2txt/inference_utils:inference_wrapper_base",
],
)
py_binary(
name = "train",
srcs = ["train.py"],
srcs_version = "PY2AND3",
deps = [
":configuration",
":show_and_tell_model",
],
)
py_binary(
name = "evaluate",
srcs = ["evaluate.py"],
srcs_version = "PY2AND3",
deps = [
":configuration",
":show_and_tell_model",
],
)
py_binary(
name = "run_inference",
srcs = ["run_inference.py"],
srcs_version = "PY2AND3",
deps = [
":configuration",
":inference_wrapper",
"//im2txt/inference_utils:caption_generator",
"//im2txt/inference_utils:vocabulary",
],
)
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Image-to-text model and training configurations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
class ModelConfig(object):
"""Wrapper class for model hyperparameters."""
def __init__(self):
"""Sets the default model hyperparameters."""
# File pattern of sharded TFRecord file containing SequenceExample protos.
# Must be provided in training and evaluation modes.
self.input_file_pattern = None
# Image format ("jpeg" or "png").
self.image_format = "jpeg"
# Approximate number of values per input shard. Used to ensure sufficient
# mixing between shards in training.
self.values_per_input_shard = 2300
# Minimum number of shards to keep in the input queue.
self.input_queue_capacity_factor = 2
# Number of threads for prefetching SequenceExample protos.
self.num_input_reader_threads = 1
# Name of the SequenceExample context feature containing image data.
self.image_feature_name = "image/data"
# Name of the SequenceExample feature list containing integer captions.
self.caption_feature_name = "image/caption_ids"
# Number of unique words in the vocab (plus 1, for <UNK>).
# The default value is larger than the expected actual vocab size to allow
# for differences between tokenizer versions used in preprocessing. There is
# no harm in using a value greater than the actual vocab size, but using a
# value less than the actual vocab size will result in an error.
self.vocab_size = 12000
# Number of threads for image preprocessing. Should be a multiple of 2.
self.num_preprocess_threads = 4
# Batch size.
self.batch_size = 32
# File containing an Inception v3 checkpoint to initialize the variables
# of the Inception model. Must be provided when starting training for the
# first time.
self.inception_checkpoint_file = None
# Dimensions of Inception v3 input images.
self.image_height = 299
self.image_width = 299
# Scale used to initialize model variables.
self.initializer_scale = 0.08
# LSTM input and output dimensionality, respectively.
self.embedding_size = 512
self.num_lstm_units = 512
# If < 1.0, the dropout keep probability applied to LSTM variables.
self.lstm_dropout_keep_prob = 0.7
# How many model checkpoints to keep.
self.max_checkpoints_to_keep = 5
self.keep_checkpoint_every_n_hours = 10000
class TrainingConfig(object):
"""Wrapper class for training hyperparameters."""
def __init__(self):
"""Sets the default training hyperparameters."""
# Number of examples per epoch of training data.
self.num_examples_per_epoch = 586363
# Optimizer for training the model.
self.optimizer = "SGD"
# Learning rate for the initial phase of training.
self.initial_learning_rate = 2.0
self.learning_rate_decay_factor = 0.5
self.num_epochs_per_decay = 8.0
# Learning rate when fine tuning the Inception v3 parameters.
self.train_inception_learning_rate = 0.0005
# If not None, clip gradients to this value.
self.clip_gradients = 5.0
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converts MSCOCO data to TFRecord file format with SequenceExample protos.
The MSCOCO images are expected to reside in JPEG files located in the following
directory structure:
train_image_dir/COCO_train2014_000000000151.jpg
train_image_dir/COCO_train2014_000000000260.jpg
...
and
val_image_dir/COCO_val2014_000000000042.jpg
val_image_dir/COCO_val2014_000000000073.jpg
...
The MSCOCO annotations JSON files are expected to reside in train_captions_file
and val_captions_file respectively.
This script converts the combined MSCOCO data into sharded data files consisting
of 256, 4 and 8 TFRecord files, respectively:
output_dir/train-00000-of-00256
output_dir/train-00001-of-00256
...
output_dir/train-00255-of-00256
and
output_dir/val-00000-of-00004
...
output_dir/val-00003-of-00004
and
output_dir/test-00000-of-00008
...
output_dir/test-00007-of-00008
Each TFRecord file contains ~2300 records. Each record within the TFRecord file
is a serialized SequenceExample proto consisting of precisely one image-caption
pair. Note that each image has multiple captions (usually 5) and therefore each
image is replicated multiple times in the TFRecord files.
The SequenceExample proto contains the following fields:
context:
image/image_id: integer MSCOCO image identifier
image/data: string containing JPEG encoded image in RGB colorspace
feature_lists:
image/caption: list of strings containing the (tokenized) caption words
image/caption_ids: list of integer ids corresponding to the caption words
The captions are tokenized using the NLTK (http://www.nltk.org/) word tokenizer.
The vocabulary of word identifiers is constructed from the sorted list (by
descending frequency) of word tokens in the training set. Only tokens appearing
at least 4 times are considered; all other words get the "unknown" word id.
NOTE: This script will consume around 100GB of disk space because each image
in the MSCOCO dataset is replicated ~5 times (once per caption) in the output.
This is done for two reasons:
1. In order to better shuffle the training data.
2. It makes it easier to perform asynchronous preprocessing of each image in
TensorFlow.
Running this script using 16 threads may take around 1 hour on an HP Z420.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import Counter
from collections import namedtuple
from datetime import datetime
import json
import os.path
import random
import sys
import threading
import nltk.tokenize
import numpy as np
import tensorflow as tf
tf.flags.DEFINE_string("train_image_dir", "/tmp/train2014/",
"Training image directory.")
tf.flags.DEFINE_string("val_image_dir", "/tmp/val2014",
"Validation image directory.")
tf.flags.DEFINE_string("train_captions_file", "/tmp/captions_train2014.json",
"Training captions JSON file.")
tf.flags.DEFINE_string("val_captions_file", "/tmp/captions_train2014.json",
"Validation captions JSON file.")
tf.flags.DEFINE_string("output_dir", "/tmp/", "Output data directory.")
tf.flags.DEFINE_integer("train_shards", 256,
"Number of shards in training TFRecord files.")
tf.flags.DEFINE_integer("val_shards", 4,
"Number of shards in validation TFRecord files.")
tf.flags.DEFINE_integer("test_shards", 8,
"Number of shards in testing TFRecord files.")
tf.flags.DEFINE_string("start_word", "<S>",
"Special word added to the beginning of each sentence.")
tf.flags.DEFINE_string("end_word", "</S>",
"Special word added to the end of each sentence.")
tf.flags.DEFINE_string("unknown_word", "<UNK>",
"Special word meaning 'unknown'.")
tf.flags.DEFINE_integer("min_word_count", 4,
"The minimum number of occurrences of each word in the "
"training set for inclusion in the vocabulary.")
tf.flags.DEFINE_string("word_counts_output_file", "/tmp/word_counts.txt",
"Output vocabulary file of word counts.")
tf.flags.DEFINE_integer("num_threads", 8,
"Number of threads to preprocess the images.")
FLAGS = tf.flags.FLAGS
ImageMetadata = namedtuple("ImageMetadata",
["image_id", "filename", "captions"])
class Vocabulary(object):
"""Simple vocabulary wrapper."""
def __init__(self, vocab, unk_id):
"""Initializes the vocabulary.
Args:
vocab: A dictionary of word to word_id.
unk_id: Id of the special 'unknown' word.
"""
self._vocab = vocab
self._unk_id = unk_id
def word_to_id(self, word):
"""Returns the integer id of a word string."""
if word in self._vocab:
return self._vocab[word]
else:
return self._unk_id
class ImageDecoder(object):
"""Helper class for decoding images in TensorFlow."""
def __init__(self):
# Create a single TensorFlow Session for all image decoding calls.
self._sess = tf.Session()
# TensorFlow ops for JPEG decoding.
self._encoded_jpeg = tf.placeholder(dtype=tf.string)
self._decode_jpeg = tf.image.decode_jpeg(self._encoded_jpeg, channels=3)
def decode_jpeg(self, encoded_jpeg):
image = self._sess.run(self._decode_jpeg,
feed_dict={self._encoded_jpeg: encoded_jpeg})
assert len(image.shape) == 3
assert image.shape[2] == 3
return image
def _int64_feature(value):
"""Wrapper for inserting an int64 Feature into a SequenceExample proto."""
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def _bytes_feature(value):
"""Wrapper for inserting a bytes Feature into a SequenceExample proto."""
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[str(value)]))
def _int64_feature_list(values):
"""Wrapper for inserting an int64 FeatureList into a SequenceExample proto."""
return tf.train.FeatureList(feature=[_int64_feature(v) for v in values])
def _bytes_feature_list(values):
"""Wrapper for inserting a bytes FeatureList into a SequenceExample proto."""
return tf.train.FeatureList(feature=[_bytes_feature(v) for v in values])
def _to_sequence_example(image, decoder, vocab):
"""Builds a SequenceExample proto for an image-caption pair.
Args:
image: An ImageMetadata object.
decoder: An ImageDecoder object.
vocab: A Vocabulary object.
Returns:
A SequenceExample proto.
"""
with tf.gfile.FastGFile(image.filename, "r") as f:
encoded_image = f.read()
try:
decoder.decode_jpeg(encoded_image)
except (tf.errors.InvalidArgumentError, AssertionError):
print("Skipping file with invalid JPEG data: %s" % image.filename)
return
context = tf.train.Features(feature={
"image/image_id": _int64_feature(image.image_id),
"image/data": _bytes_feature(encoded_image),
})
assert len(image.captions) == 1
caption = image.captions[0]
caption_ids = [vocab.word_to_id(word) for word in caption]
feature_lists = tf.train.FeatureLists(feature_list={
"image/caption": _bytes_feature_list(caption),
"image/caption_ids": _int64_feature_list(caption_ids)
})
sequence_example = tf.train.SequenceExample(
context=context, feature_lists=feature_lists)
return sequence_example
def _process_image_files(thread_index, ranges, name, images, decoder, vocab,
num_shards):
"""Processes and saves a subset of images as TFRecord files in one thread.
Args:
thread_index: Integer thread identifier within [0, len(ranges)].
ranges: A list of pairs of integers specifying the ranges of the dataset to
process in parallel.
name: Unique identifier specifying the dataset.
images: List of ImageMetadata.
decoder: An ImageDecoder object.
vocab: A Vocabulary object.
num_shards: Integer number of shards for the output files.
"""
# Each thread produces N shards where N = num_shards / num_threads. For
# instance, if num_shards = 128, and num_threads = 2, then the first thread
# would produce shards [0, 64).
num_threads = len(ranges)
assert not num_shards % num_threads
num_shards_per_batch = int(num_shards / num_threads)
shard_ranges = np.linspace(ranges[thread_index][0], ranges[thread_index][1],
num_shards_per_batch + 1).astype(int)
num_images_in_thread = ranges[thread_index][1] - ranges[thread_index][0]
counter = 0
for s in xrange(num_shards_per_batch):
# Generate a sharded version of the file name, e.g. 'train-00002-of-00010'
shard = thread_index * num_shards_per_batch + s
output_filename = "%s-%.5d-of-%.5d" % (name, shard, num_shards)
output_file = os.path.join(FLAGS.output_dir, output_filename)
writer = tf.python_io.TFRecordWriter(output_file)
shard_counter = 0
images_in_shard = np.arange(shard_ranges[s], shard_ranges[s + 1], dtype=int)
for i in images_in_shard:
image = images[i]
sequence_example = _to_sequence_example(image, decoder, vocab)
if sequence_example is not None:
writer.write(sequence_example.SerializeToString())
shard_counter += 1
counter += 1
if not counter % 1000:
print("%s [thread %d]: Processed %d of %d items in thread batch." %
(datetime.now(), thread_index, counter, num_images_in_thread))
sys.stdout.flush()
print("%s [thread %d]: Wrote %d image-caption pairs to %s" %
(datetime.now(), thread_index, shard_counter, output_file))
sys.stdout.flush()
shard_counter = 0
print("%s [thread %d]: Wrote %d image-caption pairs to %d shards." %
(datetime.now(), thread_index, counter, num_shards_per_batch))
sys.stdout.flush()
def _process_dataset(name, images, vocab, num_shards):
"""Processes a complete data set and saves it as a TFRecord.
Args:
name: Unique identifier specifying the dataset.
images: List of ImageMetadata.
vocab: A Vocabulary object.
num_shards: Integer number of shards for the output files.
"""
# Break up each image into a separate entity for each caption.
images = [ImageMetadata(image.image_id, image.filename, [caption])
for image in images for caption in image.captions]
# Shuffle the ordering of images. Make the randomization repeatable.
random.seed(12345)
random.shuffle(images)
# Break the images into num_threads batches. Batch i is defined as
# images[ranges[i][0]:ranges[i][1]].
num_threads = min(num_shards, FLAGS.num_threads)
spacing = np.linspace(0, len(images), num_threads + 1).astype(np.int)
ranges = []
threads = []
for i in xrange(len(spacing) - 1):
ranges.append([spacing[i], spacing[i + 1]])
# Create a mechanism for monitoring when all threads are finished.
coord = tf.train.Coordinator()
# Create a utility for decoding JPEG images to run sanity checks.
decoder = ImageDecoder()
# Launch a thread for each batch.
print("Launching %d threads for spacings: %s" % (num_threads, ranges))
for thread_index in xrange(len(ranges)):
args = (thread_index, ranges, name, images, decoder, vocab, num_shards)
t = threading.Thread(target=_process_image_files, args=args)
t.start()
threads.append(t)
# Wait for all the threads to terminate.
coord.join(threads)
print("%s: Finished processing all %d image-caption pairs in data set '%s'." %
(datetime.now(), len(images), name))
def _create_vocab(captions):
"""Creates the vocabulary of word to word_id.
The vocabulary is saved to disk in a text file of word counts. The id of each
word in the file is its corresponding 0-based line number.
Args:
captions: A list of lists of strings.
Returns:
A Vocabulary object.
"""
print("Creating vocabulary.")
counter = Counter()
for c in captions:
counter.update(c)
print("Total words:", len(counter))
# Filter uncommon words and sort by descending count.
word_counts = [x for x in counter.items() if x[1] >= FLAGS.min_word_count]
word_counts.sort(key=lambda x: x[1], reverse=True)
print("Words in vocabulary:", len(word_counts))
# Write out the word counts file.
with tf.gfile.FastGFile(FLAGS.word_counts_output_file, "w") as f:
f.write("\n".join(["%s %d" % (w, c) for w, c in word_counts]))
print("Wrote vocabulary file:", FLAGS.word_counts_output_file)
# Create the vocabulary dictionary.
reverse_vocab = [x[0] for x in word_counts]
unk_id = len(reverse_vocab)
vocab_dict = dict([(x, y) for (y, x) in enumerate(reverse_vocab)])
vocab = Vocabulary(vocab_dict, unk_id)
return vocab
def _process_caption(caption):
"""Processes a caption string into a list of tonenized words.
Args:
caption: A string caption.
Returns:
A list of strings; the tokenized caption.
"""
tokenized_caption = [FLAGS.start_word]
tokenized_caption.extend(nltk.tokenize.word_tokenize(caption.lower()))
tokenized_caption.append(FLAGS.end_word)
return tokenized_caption
def _load_and_process_metadata(captions_file, image_dir):
"""Loads image metadata from a JSON file and processes the captions.
Args:
captions_file: JSON file containing caption annotations.
image_dir: Directory containing the image files.
Returns:
A list of ImageMetadata.
"""
with tf.gfile.FastGFile(captions_file, "r") as f:
caption_data = json.load(f)
# Extract the filenames.
id_to_filename = [(x["id"], x["file_name"]) for x in caption_data["images"]]
# Extract the captions. Each image_id is associated with multiple captions.
id_to_captions = {}
for annotation in caption_data["annotations"]:
image_id = annotation["image_id"]
caption = annotation["caption"]
id_to_captions.setdefault(image_id, [])
id_to_captions[image_id].append(caption)
assert len(id_to_filename) == len(id_to_captions)
assert set([x[0] for x in id_to_filename]) == set(id_to_captions.keys())
print("Loaded caption metadata for %d images from %s" %
(len(id_to_filename), captions_file))
# Process the captions and combine the data into a list of ImageMetadata.
print("Proccessing captions.")
image_metadata = []
num_captions = 0
for image_id, base_filename in id_to_filename:
filename = os.path.join(image_dir, base_filename)
captions = [_process_caption(c) for c in id_to_captions[image_id]]
image_metadata.append(ImageMetadata(image_id, filename, captions))
num_captions += len(captions)
print("Finished processing %d captions for %d images in %s" %
(num_captions, len(id_to_filename), captions_file))
return image_metadata
def main(unused_argv):
def _is_valid_num_shards(num_shards):
"""Returns True if num_shards is compatible with FLAGS.num_threads."""
return num_shards < FLAGS.num_threads or not num_shards % FLAGS.num_threads
assert _is_valid_num_shards(FLAGS.train_shards), (
"Please make the FLAGS.num_threads commensurate with FLAGS.train_shards")
assert _is_valid_num_shards(FLAGS.val_shards), (
"Please make the FLAGS.num_threads commensurate with FLAGS.val_shards")
assert _is_valid_num_shards(FLAGS.test_shards), (
"Please make the FLAGS.num_threads commensurate with FLAGS.test_shards")
if not tf.gfile.IsDirectory(FLAGS.output_dir):
tf.gfile.MakeDirs(FLAGS.output_dir)
# Load image metadata from caption files.
mscoco_train_dataset = _load_and_process_metadata(FLAGS.train_captions_file,
FLAGS.train_image_dir)
mscoco_val_dataset = _load_and_process_metadata(FLAGS.val_captions_file,
FLAGS.val_image_dir)
# Redistribute the MSCOCO data as follows:
# train_dataset = 100% of mscoco_train_dataset + 85% of mscoco_val_dataset.
# val_dataset = 5% of mscoco_val_dataset (for validation during training).
# test_dataset = 10% of mscoco_val_dataset (for final evaluation).
train_cutoff = int(0.85 * len(mscoco_val_dataset))
val_cutoff = int(0.90 * len(mscoco_val_dataset))
train_dataset = mscoco_train_dataset + mscoco_val_dataset[0:train_cutoff]
val_dataset = mscoco_val_dataset[train_cutoff:val_cutoff]
test_dataset = mscoco_val_dataset[val_cutoff:]
# Create vocabulary from the training captions.
train_captions = [c for image in train_dataset for c in image.captions]
vocab = _create_vocab(train_captions)
_process_dataset("train", train_dataset, vocab, FLAGS.train_shards)
_process_dataset("val", val_dataset, vocab, FLAGS.val_shards)
_process_dataset("test", test_dataset, vocab, FLAGS.test_shards)
if __name__ == "__main__":
tf.app.run()
#!/bin/bash
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Script to download and preprocess the MSCOCO data set.
#
# The outputs of this script are sharded TFRecord files containing serialized
# SequenceExample protocol buffers. See build_mscoco_data.py for details of how
# the SequenceExample protocol buffers are constructed.
#
# usage:
# ./download_and_preprocess_mscoco.sh
set -e
if [ -z "$1" ]; then
echo "usage download_and_preproces_mscoco.sh [data dir]"
exit
fi
# Create the output directories.
OUTPUT_DIR="${1%/}"
SCRATCH_DIR="${OUTPUT_DIR}/raw-data"
mkdir -p "${OUTPUT_DIR}"
mkdir -p "${SCRATCH_DIR}"
CURRENT_DIR=$(pwd)
WORK_DIR="$0.runfiles/__main__/im2txt"
# Helper function to download and unpack a .zip file.
function download_and_unzip() {
local BASE_URL=${1}
local FILENAME=${2}
if [ ! -f ${FILENAME} ]; then
echo "Downloading ${FILENAME} to $(pwd)"
wget -nd -c "${BASE_URL}/${FILENAME}"
else
echo "Skipping download of ${FILENAME}"
fi
echo "Unzipping ${FILENAME}"
unzip -nq ${FILENAME}
}
cd ${SCRATCH_DIR}
# Download the images.
BASE_IMAGE_URL="http://msvocds.blob.core.windows.net/coco2014"
TRAIN_IMAGE_FILE="train2014.zip"
download_and_unzip ${BASE_IMAGE_URL} ${TRAIN_IMAGE_FILE}
TRAIN_IMAGE_DIR="${SCRATCH_DIR}/train2014"
VAL_IMAGE_FILE="val2014.zip"
download_and_unzip ${BASE_IMAGE_URL} ${VAL_IMAGE_FILE}
VAL_IMAGE_DIR="${SCRATCH_DIR}/val2014"
# Download the captions.
BASE_CAPTIONS_URL="http://msvocds.blob.core.windows.net/annotations-1-0-3"
CAPTIONS_FILE="captions_train-val2014.zip"
download_and_unzip ${BASE_CAPTIONS_URL} ${CAPTIONS_FILE}
TRAIN_CAPTIONS_FILE="${SCRATCH_DIR}/annotations/captions_train2014.json"
VAL_CAPTIONS_FILE="${SCRATCH_DIR}/annotations/captions_val2014.json"
# Build TFRecords of the image data.
cd "${CURRENT_DIR}"
BUILD_SCRIPT="${WORK_DIR}/build_mscoco_data"
"${BUILD_SCRIPT}" \
--train_image_dir="${TRAIN_IMAGE_DIR}" \
--val_image_dir="${VAL_IMAGE_DIR}" \
--train_captions_file="${TRAIN_CAPTIONS_FILE}" \
--val_captions_file="${VAL_CAPTIONS_FILE}" \
--output_dir="${OUTPUT_DIR}" \
  --word_counts_output_file="${OUTPUT_DIR}/word_counts.txt"
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluate the model.
This script should be run concurrently with training so that summaries show up
in TensorBoard.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import os.path
import time
import numpy as np
import tensorflow as tf
from im2txt import configuration
from im2txt import show_and_tell_model
FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_string("input_file_pattern", "",
"File pattern of sharded TFRecord input files.")
tf.flags.DEFINE_string("checkpoint_dir", "",
"Directory containing model checkpoints.")
tf.flags.DEFINE_string("eval_dir", "", "Directory to write event logs.")
tf.flags.DEFINE_integer("eval_interval_secs", 600,
"Interval between evaluation runs.")
tf.flags.DEFINE_integer("num_eval_examples", 10132,
"Number of examples for evaluation.")
tf.flags.DEFINE_integer("min_global_step", 5000,
"Minimum global step to run evaluation.")
tf.logging.set_verbosity(tf.logging.INFO)
def evaluate_model(sess, model, global_step, summary_writer, summary_op):
"""Computes perplexity-per-word over the evaluation dataset.
Summaries and perplexity-per-word are written out to the eval directory.
Args:
sess: Session object.
model: Instance of ShowAndTellModel; the model to evaluate.
global_step: Integer; global step of the model checkpoint.
summary_writer: Instance of SummaryWriter.
summary_op: Op for generating model summaries.
"""
# Log model summaries on a single batch.
summary_str = sess.run(summary_op)
summary_writer.add_summary(summary_str, global_step)
# Compute perplexity over the entire dataset.
num_eval_batches = int(
math.ceil(FLAGS.num_eval_examples / model.config.batch_size))
start_time = time.time()
sum_losses = 0.
sum_weights = 0.
for i in xrange(num_eval_batches):
cross_entropy_losses, weights = sess.run([
model.target_cross_entropy_losses,
model.target_cross_entropy_loss_weights
])
sum_losses += np.sum(cross_entropy_losses * weights)
sum_weights += np.sum(weights)
if not i % 100:
tf.logging.info("Computed losses for %d of %d batches.", i + 1,
num_eval_batches)
eval_time = time.time() - start_time
perplexity = math.exp(sum_losses / sum_weights)
tf.logging.info("Perplexity = %f (%.2g sec)", perplexity, eval_time)
# Log perplexity to the SummaryWriter.
summary = tf.Summary()
value = summary.value.add()
value.simple_value = perplexity
value.tag = "Perplexity"
summary_writer.add_summary(summary, global_step)
# Write the Events file to the eval directory.
summary_writer.flush()
tf.logging.info("Finished processing evaluation at global step %d.",
global_step)
def run_once(model, summary_writer, summary_op):
"""Evaluates the latest model checkpoint.
Args:
model: Instance of ShowAndTellModel; the model to evaluate.
summary_writer: Instance of SummaryWriter.
summary_op: Op for generating model summaries.
"""
model_path = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
if not model_path:
tf.logging.info("Skipping evaluation. No checkpoint found in: %s",
FLAGS.checkpoint_dir)
return
with tf.Session() as sess:
# Load model from checkpoint.
tf.logging.info("Loading model from checkpoint: %s", model_path)
model.saver.restore(sess, model_path)
global_step = tf.train.global_step(sess, model.global_step.name)
tf.logging.info("Successfully loaded %s at global step = %d.",
os.path.basename(model_path), global_step)
if global_step < FLAGS.min_global_step:
tf.logging.info("Skipping evaluation. Global step = %d < %d", global_step,
FLAGS.min_global_step)
return
# Start the queue runners.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
# Run evaluation on the latest checkpoint.
try:
evaluate_model(
sess=sess,
model=model,
global_step=global_step,
summary_writer=summary_writer,
summary_op=summary_op)
    except Exception as e:  # pylint: disable=broad-except
tf.logging.error("Evaluation failed.")
coord.request_stop(e)
coord.request_stop()
coord.join(threads, stop_grace_period_secs=10)
def run():
"""Runs evaluation in a loop, and logs summaries to TensorBoard."""
# Create the evaluation directory if it doesn't exist.
eval_dir = FLAGS.eval_dir
if not tf.gfile.IsDirectory(eval_dir):
tf.logging.info("Creating eval directory: %s", eval_dir)
tf.gfile.MakeDirs(eval_dir)
g = tf.Graph()
with g.as_default():
# Build the model for evaluation.
model_config = configuration.ModelConfig()
model_config.input_file_pattern = FLAGS.input_file_pattern
model = show_and_tell_model.ShowAndTellModel(model_config, mode="eval")
model.build()
# Create the summary operation and the summary writer.
summary_op = tf.merge_all_summaries()
summary_writer = tf.train.SummaryWriter(eval_dir)
g.finalize()
# Run a new evaluation run every eval_interval_secs.
while True:
start = time.time()
tf.logging.info("Starting evaluation at " + time.strftime(
"%Y-%m-%d-%H:%M:%S", time.localtime()))
run_once(model, summary_writer, summary_op)
time_to_next_eval = start + FLAGS.eval_interval_secs - time.time()
if time_to_next_eval > 0:
time.sleep(time_to_next_eval)
def main(unused_argv):
assert FLAGS.input_file_pattern, "--input_file_pattern is required"
assert FLAGS.checkpoint_dir, "--checkpoint_dir is required"
assert FLAGS.eval_dir, "--eval_dir is required"
run()
if __name__ == "__main__":
tf.app.run()
package(default_visibility = ["//im2txt:internal"])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
py_library(
name = "inference_wrapper_base",
srcs = ["inference_wrapper_base.py"],
srcs_version = "PY2AND3",
)
py_library(
name = "vocabulary",
srcs = ["vocabulary.py"],
srcs_version = "PY2AND3",
)
py_library(
name = "caption_generator",
srcs = ["caption_generator.py"],
srcs_version = "PY2AND3",
)
py_test(
name = "caption_generator_test",
srcs = ["caption_generator_test.py"],
deps = [
":caption_generator",
],
)
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class for generating captions from an image-to-text model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import heapq
import math
import numpy as np
class Caption(object):
"""Represents a complete or partial caption."""
def __init__(self, sentence, state, logprob, score, metadata=None):
"""Initializes the Caption.
Args:
sentence: List of word ids in the caption.
state: Model state after generating the previous word.
logprob: Log-probability of the caption.
score: Score of the caption.
metadata: Optional metadata associated with the partial sentence. If not
None, a list of strings with the same length as 'sentence'.
"""
self.sentence = sentence
self.state = state
self.logprob = logprob
self.score = score
self.metadata = metadata
def __cmp__(self, other):
"""Compares Captions by score."""
assert isinstance(other, Caption)
if self.score == other.score:
return 0
elif self.score < other.score:
return -1
else:
return 1
class TopN(object):
"""Maintains the top n elements of an incrementally provided set."""
def __init__(self, n):
self._n = n
self._data = []
def size(self):
assert self._data is not None
return len(self._data)
def push(self, x):
"""Pushes a new element."""
assert self._data is not None
if len(self._data) < self._n:
heapq.heappush(self._data, x)
else:
heapq.heappushpop(self._data, x)
def extract(self, sort=False):
"""Extracts all elements from the TopN. This is a destructive operation.
The only method that can be called immediately after extract() is reset().
Args:
sort: Whether to return the elements in descending sorted order.
Returns:
A list of data; the top n elements provided to the set.
"""
assert self._data is not None
data = self._data
self._data = None
if sort:
data.sort(reverse=True)
return data
def reset(self):
"""Returns the TopN to an empty state."""
self._data = []
class CaptionGenerator(object):
"""Class to generate captions from an image-to-text model."""
def __init__(self,
model,
vocab,
beam_size=3,
max_caption_length=20,
length_normalization_factor=0.0):
"""Initializes the generator.
Args:
model: Object encapsulating a trained image-to-text model. Must have
methods feed_image() and inference_step(). For example, an instance of
InferenceWrapperBase.
vocab: A Vocabulary object.
beam_size: Beam size to use when generating captions.
max_caption_length: The maximum caption length before stopping the search.
length_normalization_factor: If != 0, a number x such that captions are
scored by logprob/length^x, rather than logprob. This changes the
relative scores of captions depending on their lengths. For example, if
x > 0 then longer captions will be favored.
"""
self.vocab = vocab
self.model = model
self.beam_size = beam_size
self.max_caption_length = max_caption_length
self.length_normalization_factor = length_normalization_factor
def beam_search(self, sess, encoded_image):
"""Runs beam search caption generation on a single image.
Args:
sess: TensorFlow Session object.
encoded_image: An encoded image string.
Returns:
A list of Caption sorted by descending score.
"""
# Feed in the image to get the initial state.
initial_state = self.model.feed_image(sess, encoded_image)
initial_beam = Caption(
sentence=[self.vocab.start_id],
state=initial_state[0],
logprob=0.0,
score=0.0,
metadata=[""])
partial_captions = TopN(self.beam_size)
partial_captions.push(initial_beam)
complete_captions = TopN(self.beam_size)
# Run beam search.
for _ in range(self.max_caption_length - 1):
partial_captions_list = partial_captions.extract()
partial_captions.reset()
input_feed = np.array([c.sentence[-1] for c in partial_captions_list])
state_feed = np.array([c.state for c in partial_captions_list])
softmax, new_states, metadata = self.model.inference_step(sess,
input_feed,
state_feed)
for i, partial_caption in enumerate(partial_captions_list):
word_probabilities = softmax[i]
state = new_states[i]
# For this partial caption, get the beam_size most probable next words.
words_and_probs = list(enumerate(word_probabilities))
words_and_probs.sort(key=lambda x: -x[1])
words_and_probs = words_and_probs[0:self.beam_size]
# Each next word gives a new partial caption.
for w, p in words_and_probs:
if p < 1e-12:
continue # Avoid log(0).
sentence = partial_caption.sentence + [w]
logprob = partial_caption.logprob + math.log(p)
score = logprob
if metadata:
metadata_list = partial_caption.metadata + [metadata[i]]
else:
metadata_list = None
if w == self.vocab.end_id:
if self.length_normalization_factor > 0:
score /= len(sentence)**self.length_normalization_factor
beam = Caption(sentence, state, logprob, score, metadata_list)
complete_captions.push(beam)
else:
beam = Caption(sentence, state, logprob, score, metadata_list)
partial_captions.push(beam)
if partial_captions.size() == 0:
# We have run out of partial candidates; happens when beam_size = 1.
break
# If we have no complete captions then fall back to the partial captions.
# But never output a mixture of complete and partial captions because a
# partial caption could have a higher score than all the complete captions.
if not complete_captions.size():
complete_captions = partial_captions
return complete_captions.extract(sort=True)
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for CaptionGenerator."""
import math
import numpy as np
import tensorflow as tf
from im2txt.inference_utils import caption_generator
class FakeVocab(object):
"""Fake Vocabulary for testing purposes."""
def __init__(self):
self.start_id = 0 # Word id denoting sentence start.
self.end_id = 1 # Word id denoting sentence end.
class FakeModel(object):
"""Fake model for testing purposes."""
def __init__(self):
# Number of words in the vocab.
self._vocab_size = 12
# Dimensionality of the nominal model state.
self._state_size = 1
# Map of previous word to the probability distribution of the next word.
self._probabilities = {
0: {1: 0.1,
2: 0.2,
3: 0.3,
4: 0.4},
2: {5: 0.1,
6: 0.9},
3: {1: 0.1,
7: 0.4,
8: 0.5},
4: {1: 0.3,
9: 0.3,
10: 0.4},
5: {1: 1.0},
6: {1: 1.0},
7: {1: 1.0},
8: {1: 1.0},
9: {1: 0.5,
11: 0.5},
10: {1: 1.0},
11: {1: 1.0},
}
# pylint: disable=unused-argument
def feed_image(self, sess, encoded_image):
# Return a nominal model state.
return np.zeros([1, self._state_size])
def inference_step(self, sess, input_feed, state_feed):
# Compute the matrix of softmax distributions for the next batch of words.
batch_size = input_feed.shape[0]
softmax_output = np.zeros([batch_size, self._vocab_size])
for batch_index, word_id in enumerate(input_feed):
for next_word, probability in self._probabilities[word_id].items():
softmax_output[batch_index, next_word] = probability
# Nominal state and metadata.
new_state = np.zeros([batch_size, self._state_size])
metadata = None
return softmax_output, new_state, metadata
# pylint: enable=unused-argument
class CaptionGeneratorTest(tf.test.TestCase):
def _assertExpectedCaptions(self,
expected_captions,
beam_size=3,
max_caption_length=20,
length_normalization_factor=0):
"""Tests that beam search generates the expected captions.
Args:
expected_captions: A sequence of pairs (sentence, probability), where
sentence is a list of integer ids and probability is a float in [0, 1].
beam_size: Parameter passed to beam_search().
max_caption_length: Parameter passed to beam_search().
length_normalization_factor: Parameter passed to beam_search().
"""
expected_sentences = [c[0] for c in expected_captions]
expected_probabilities = [c[1] for c in expected_captions]
# Generate captions.
generator = caption_generator.CaptionGenerator(
model=FakeModel(),
vocab=FakeVocab(),
beam_size=beam_size,
max_caption_length=max_caption_length,
length_normalization_factor=length_normalization_factor)
actual_captions = generator.beam_search(sess=None, encoded_image=None)
actual_sentences = [c.sentence for c in actual_captions]
actual_probabilities = [math.exp(c.logprob) for c in actual_captions]
self.assertEqual(expected_sentences, actual_sentences)
self.assertAllClose(expected_probabilities, actual_probabilities)
def testBeamSize(self):
# Beam size = 1.
expected = [([0, 4, 10, 1], 0.16)]
self._assertExpectedCaptions(expected, beam_size=1)
# Beam size = 2.
expected = [([0, 4, 10, 1], 0.16), ([0, 3, 8, 1], 0.15)]
self._assertExpectedCaptions(expected, beam_size=2)
# Beam size = 3.
expected = [
([0, 2, 6, 1], 0.18), ([0, 4, 10, 1], 0.16), ([0, 3, 8, 1], 0.15)
]
self._assertExpectedCaptions(expected, beam_size=3)
def testMaxLength(self):
# Max length = 1.
expected = [([0], 1.0)]
self._assertExpectedCaptions(expected, max_caption_length=1)
# Max length = 2.
# There are no complete sentences, so partial sentences are returned.
expected = [([0, 4], 0.4), ([0, 3], 0.3), ([0, 2], 0.2)]
self._assertExpectedCaptions(expected, max_caption_length=2)
# Max length = 3.
# There is at least one complete sentence, so only complete sentences are
# returned.
expected = [([0, 4, 1], 0.12), ([0, 3, 1], 0.03)]
self._assertExpectedCaptions(expected, max_caption_length=3)
# Max length = 4.
expected = [
([0, 2, 6, 1], 0.18), ([0, 4, 10, 1], 0.16), ([0, 3, 8, 1], 0.15)
]
self._assertExpectedCaptions(expected, max_caption_length=4)
def testLengthNormalization(self):
# Length normalization factor = 3.
# The longest caption is returned first, despite having low probability,
# because it has the highest log(probability)/length**3.
expected = [
([0, 4, 9, 11, 1], 0.06),
([0, 2, 6, 1], 0.18),
([0, 4, 10, 1], 0.16),
([0, 3, 8, 1], 0.15),
]
self._assertExpectedCaptions(
expected, beam_size=4, length_normalization_factor=3)
if __name__ == '__main__':
tf.test.main()
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base wrapper class for performing inference with an image-to-text model.
Subclasses must implement the following methods:
build_model():
Builds the model for inference and returns the model object.
feed_image():
Takes an encoded image and returns the initial model state, where "state"
is a numpy array whose specifics are defined by the subclass, e.g.
concatenated LSTM state. It's assumed that feed_image() will be called
precisely once at the start of inference for each image. Subclasses may
compute and/or save per-image internal context in this method.
inference_step():
Takes a batch of inputs and states at a single time-step. Returns the
softmax output corresponding to the inputs, and the new states of the batch.
Optionally also returns metadata about the current inference step, e.g. a
serialized numpy array containing activations from a particular model layer.
Client usage:
1. Build the model inference graph via build_graph_from_config() or
build_graph_from_proto().
2. Call the resulting restore_fn to load the model checkpoint.
3. For each image in a batch of images:
a) Call feed_image() once to get the initial state.
b) For each step of caption generation, call inference_step().
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import tensorflow as tf
# pylint: disable=unused-argument
class InferenceWrapperBase(object):
"""Base wrapper class for performing inference with an image-to-text model."""
def __init__(self):
pass
def build_model(self, model_config):
"""Builds the model for inference.
Args:
model_config: Object containing configuration for building the model.
Returns:
model: The model object.
"""
tf.logging.fatal("Please implement build_model in subclass")
def _create_restore_fn(self, checkpoint_path, saver):
"""Creates a function that restores a model from checkpoint.
Args:
checkpoint_path: Checkpoint file or a directory containing a checkpoint
file.
saver: Saver for restoring variables from the checkpoint file.
Returns:
restore_fn: A function such that restore_fn(sess) loads model variables
from the checkpoint file.
Raises:
ValueError: If checkpoint_path does not refer to a checkpoint file or a
directory containing a checkpoint file.
"""
if tf.gfile.IsDirectory(checkpoint_path):
checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
if not checkpoint_path:
raise ValueError("No checkpoint file found in: %s" % checkpoint_path)
def _restore_fn(sess):
tf.logging.info("Loading model from checkpoint: %s", checkpoint_path)
saver.restore(sess, checkpoint_path)
tf.logging.info("Successfully loaded checkpoint: %s",
os.path.basename(checkpoint_path))
return _restore_fn
def build_graph_from_config(self, model_config, checkpoint_path):
"""Builds the inference graph from a configuration object.
Args:
model_config: Object containing configuration for building the model.
checkpoint_path: Checkpoint file or a directory containing a checkpoint
file.
Returns:
restore_fn: A function such that restore_fn(sess) loads model variables
from the checkpoint file.
"""
tf.logging.info("Building model.")
model = self.build_model(model_config)
saver = model.saver
if not saver:
      saver = tf.train.Saver()
return self._create_restore_fn(checkpoint_path, saver)
def build_graph_from_proto(self, graph_def_file, saver_def_file,
checkpoint_path):
"""Builds the inference graph from serialized GraphDef and SaverDef protos.
Args:
graph_def_file: File containing a serialized GraphDef proto.
saver_def_file: File containing a serialized SaverDef proto.
checkpoint_path: Checkpoint file or a directory containing a checkpoint
file.
Returns:
restore_fn: A function such that restore_fn(sess) loads model variables
from the checkpoint file.
"""
# Load the Graph.
tf.logging.info("Loading GraphDef from file: %s", graph_def_file)
graph_def = tf.GraphDef()
with tf.gfile.FastGFile(graph_def_file, "rb") as f:
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name="")
# Load the Saver.
tf.logging.info("Loading SaverDef from file: %s", saver_def_file)
saver_def = tf.train.SaverDef()
with tf.gfile.FastGFile(saver_def_file, "rb") as f:
saver_def.ParseFromString(f.read())
saver = tf.train.Saver(saver_def=saver_def)
return self._create_restore_fn(checkpoint_path, saver)
def feed_image(self, sess, encoded_image):
"""Feeds an image and returns the initial model state.
See comments at the top of file.
Args:
sess: TensorFlow Session object.
encoded_image: An encoded image string.
Returns:
state: A numpy array of shape [1, state_size].
"""
tf.logging.fatal("Please implement feed_image in subclass")
def inference_step(self, sess, input_feed, state_feed):
"""Runs one step of inference.
Args:
sess: TensorFlow Session object.
input_feed: A numpy array of shape [batch_size].
state_feed: A numpy array of shape [batch_size, state_size].
Returns:
softmax_output: A numpy array of shape [batch_size, vocab_size].
new_state: A numpy array of shape [batch_size, state_size].
metadata: Optional. If not None, a string containing metadata about the
current inference step (e.g. serialized numpy array containing
        activations from a particular model layer).
"""
tf.logging.fatal("Please implement inference_step in subclass")
# pylint: enable=unused-argument
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Vocabulary class for an image-to-text model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
class Vocabulary(object):
"""Vocabulary class for an image-to-text model."""
def __init__(self,
vocab_file,
start_word="<S>",
end_word="</S>",
unk_word="<UNK>"):
"""Initializes the vocabulary.
Args:
vocab_file: File containing the vocabulary, where the words are the first
whitespace-separated token on each line (other tokens are ignored) and
the word ids are the corresponding line numbers.
start_word: Special word denoting sentence start.
end_word: Special word denoting sentence end.
unk_word: Special word denoting unknown words.
"""
if not tf.gfile.Exists(vocab_file):
tf.logging.fatal("Vocab file %s not found.", vocab_file)
tf.logging.info("Initializing vocabulary from file: %s", vocab_file)
with tf.gfile.GFile(vocab_file, mode="r") as f:
reverse_vocab = list(f.readlines())
reverse_vocab = [line.split()[0] for line in reverse_vocab]
assert start_word in reverse_vocab
assert end_word in reverse_vocab
if unk_word not in reverse_vocab:
reverse_vocab.append(unk_word)
vocab = dict([(x, y) for (y, x) in enumerate(reverse_vocab)])
tf.logging.info("Created vocabulary with %d words" % len(vocab))
self.vocab = vocab # vocab[word] = id
self.reverse_vocab = reverse_vocab # reverse_vocab[id] = word
# Save special word ids.
self.start_id = vocab[start_word]
self.end_id = vocab[end_word]
self.unk_id = vocab[unk_word]
def word_to_id(self, word):
"""Returns the integer word id of a word string."""
if word in self.vocab:
return self.vocab[word]
else:
return self.unk_id
def id_to_word(self, word_id):
"""Returns the word string of an integer word id."""
if word_id >= len(self.reverse_vocab):
return self.reverse_vocab[self.unk_id]
else:
return self.reverse_vocab[word_id]
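# Example usage sketch (the file path and its contents are hypothetical, not
# taken from this code): given a vocab file whose lines look like
# "<S> 298911", "a 200000", "</S> 298911", word ids are assigned by line
# number, starting at 0.
#
#   vocab = Vocabulary("/tmp/word_counts.txt")
#   vocab.word_to_id("a")             # id of "a", i.e. its line number
#   vocab.word_to_id("notaword")      # unknown words map to vocab.unk_id
#   vocab.id_to_word(vocab.start_id)  # "<S>"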
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model wrapper class for performing inference with a ShowAndTellModel."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from im2txt import show_and_tell_model
from im2txt.inference_utils import inference_wrapper_base
class InferenceWrapper(inference_wrapper_base.InferenceWrapperBase):
"""Model wrapper class for performing inference with a ShowAndTellModel."""
def __init__(self):
super(InferenceWrapper, self).__init__()
def build_model(self, model_config):
model = show_and_tell_model.ShowAndTellModel(model_config, mode="inference")
model.build()
return model
def feed_image(self, sess, encoded_image):
initial_state = sess.run(fetches="lstm/initial_state:0",
feed_dict={"image_feed:0": encoded_image})
return initial_state
def inference_step(self, sess, input_feed, state_feed):
softmax_output, state_output = sess.run(
fetches=["softmax:0", "lstm/state:0"],
feed_dict={
"input_feed:0": input_feed,
"lstm/state_feed:0": state_feed,
})
return softmax_output, state_output, None
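# Example usage sketch (hedged: `configuration.ModelConfig`, the checkpoint
# path, and `encoded_image_bytes` are assumptions about the surrounding
# package and caller, not defined in this file; `tf` is assumed imported):
#
#   g = tf.Graph()
#   with g.as_default():
#     model = InferenceWrapper()
#     restore_fn = model.build_graph_from_config(configuration.ModelConfig(),
#                                                "/path/to/checkpoint")
#   g.finalize()
#   with tf.Session(graph=g) as sess:
#     restore_fn(sess)
#     initial_state = model.feed_image(sess, encoded_image_bytes)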
package(default_visibility = ["//im2txt:internal"])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
py_library(
name = "image_processing",
srcs = ["image_processing.py"],
srcs_version = "PY2AND3",
)
py_library(
name = "image_embedding",
srcs = ["image_embedding.py"],
srcs_version = "PY2AND3",
)
py_test(
name = "image_embedding_test",
size = "small",
srcs = ["image_embedding_test.py"],
deps = [
":image_embedding",
],
)
py_library(
name = "inputs",
srcs = ["inputs.py"],
srcs_version = "PY2AND3",
)
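# Example test invocation (assumed target label; the package path depends on
# where this BUILD file lives in the workspace):
#   bazel test //im2txt/ops:image_embedding_test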
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Image embedding ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.contrib.slim.python.slim.nets.inception_v3 import inception_v3_base
slim = tf.contrib.slim
def inception_v3(images,
trainable=True,
is_training=True,
weight_decay=0.00004,
stddev=0.1,
dropout_keep_prob=0.8,
use_batch_norm=True,
batch_norm_params=None,
add_summaries=True,
scope="InceptionV3"):
"""Builds an Inception V3 subgraph for image embeddings.
Args:
images: A float32 Tensor of shape [batch, height, width, channels].
trainable: Whether the inception submodel should be trainable or not.
is_training: Boolean indicating training mode or not.
weight_decay: Coefficient for weight regularization.
    stddev: The standard deviation of the truncated normal weight initializer.
dropout_keep_prob: Dropout keep probability.
use_batch_norm: Whether to use batch normalization.
batch_norm_params: Parameters for batch normalization. See
tf.contrib.layers.batch_norm for details.
add_summaries: Whether to add activation summaries.
scope: Optional Variable scope.
  Returns:
    net: A float32 Tensor containing the flattened Inception v3 embedding for
      each image (the output of the final average pooling layer).
"""
# Only consider the inception model to be in training mode if it's trainable.
is_inception_model_training = trainable and is_training
if use_batch_norm:
# Default parameters for batch normalization.
if not batch_norm_params:
batch_norm_params = {
"is_training": is_inception_model_training,
"trainable": trainable,
# Decay for the moving averages.
"decay": 0.9997,
# Epsilon to prevent 0s in variance.
"epsilon": 0.001,
# Collection containing the moving mean and moving variance.
"variables_collections": {
"beta": None,
"gamma": None,
"moving_mean": ["moving_vars"],
"moving_variance": ["moving_vars"],
}
}
else:
batch_norm_params = None
if trainable:
weights_regularizer = tf.contrib.layers.l2_regularizer(weight_decay)
else:
weights_regularizer = None
with tf.variable_scope(scope, "InceptionV3", [images]) as scope:
with slim.arg_scope(
[slim.conv2d, slim.fully_connected],
weights_regularizer=weights_regularizer,
trainable=trainable):
with slim.arg_scope(
[slim.conv2d],
weights_initializer=tf.truncated_normal_initializer(stddev=stddev),
activation_fn=tf.nn.relu,
normalizer_fn=slim.batch_norm,
normalizer_params=batch_norm_params):
net, end_points = inception_v3_base(images, scope=scope)
with tf.variable_scope("logits"):
shape = net.get_shape()
net = slim.avg_pool2d(net, shape[1:3], padding="VALID", scope="pool")
net = slim.dropout(
net,
keep_prob=dropout_keep_prob,
is_training=is_inception_model_training,
scope="dropout")
net = slim.flatten(net, scope="flatten")
# Add summaries.
if add_summaries:
for v in end_points.values():
tf.contrib.layers.summaries.summarize_activation(v)
return net
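# Example usage sketch (the 299x299 input size and 2048-dimensional output are
# standard Inception v3 properties, stated here as assumptions rather than
# taken from this file):
#
#   images = tf.placeholder(tf.float32, shape=[None, 299, 299, 3])
#   embeddings = inception_v3(images, trainable=False, is_training=False)
#   # `embeddings` is the flattened output of the final average pool,
#   # i.e. a [batch, 2048] Tensor for standard Inception v3.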