Commit 83ee52cc authored by Martin Wicke's avatar Martin Wicke

added inception model

parent 1ecaf090
[submodule "tensorflow"]
path = tensorflow
url = https://github.com/tensorflow/tensorflow.git
local_repository(
name = "tf",
path = __workspace_dir__ + "/tensorflow",
)
load('//tensorflow/tensorflow:workspace.bzl', 'tf_workspace')
tf_workspace("tensorflow/")
# grpc expects //external:protobuf_clib and //external:protobuf_compiler
# to point to the protobuf's compiler library.
bind(
name = "protobuf_clib",
actual = "@tf//google/protobuf:protoc_lib",
)
bind(
name = "protobuf_compiler",
actual = "@tf//google/protobuf:protoc_lib",
)
git_repository(
name = "grpc",
commit = "73979f4",
init_submodules = True,
remote = "https://github.com/grpc/grpc.git",
)
# protobuf expects //external:grpc_cpp_plugin to point to grpc's
# C++ plugin code generator.
bind(
name = "grpc_cpp_plugin",
actual = "@grpc//:grpc_cpp_plugin",
)
bind(
name = "grpc_lib",
actual = "@grpc//:grpc++_unsecure",
)
# Description:
# Example TensorFlow models for ImageNet.
package(default_visibility = [":internal"])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
package_group(
name = "internal",
packages = ["//inception/..."],
)
py_library(
name = "dataset",
srcs = [
"dataset.py",
],
deps = [
"@tf//tensorflow:tensorflow_py",
],
)
py_library(
name = "imagenet_data",
srcs = [
"imagenet_data.py",
],
deps = [
":dataset",
],
)
py_library(
name = "flowers_data",
srcs = [
"flowers_data.py",
],
deps = [
":dataset",
],
)
py_library(
name = "image_processing",
srcs = [
"image_processing.py",
],
)
py_library(
name = "inception",
srcs = [
"inception_model.py",
],
deps = [
"@tf//tensorflow:tensorflow_py",
":dataset",
"//inception/slim",
],
)
py_binary(
name = "imagenet_eval",
srcs = [
"imagenet_eval.py",
],
deps = [
":imagenet_data",
":inception_eval",
],
)
py_binary(
name = "flowers_eval",
srcs = [
"flowers_eval.py",
],
deps = [
":flowers_data",
":inception_eval",
],
)
py_library(
name = "inception_eval",
srcs = [
"inception_eval.py",
],
deps = [
"@tf//tensorflow:tensorflow_py",
":image_processing",
":inception",
],
)
py_binary(
name = "imagenet_train",
srcs = [
"imagenet_train.py",
],
deps = [
":imagenet_data",
":inception_train",
],
)
py_binary(
name = "flowers_train",
srcs = [
"flowers_train.py",
],
deps = [
":flowers_data",
":inception_train",
],
)
py_library(
name = "inception_train",
srcs = [
"inception_train.py",
],
deps = [
"@tf//tensorflow:tensorflow_py",
":image_processing",
":inception",
],
)
py_binary(
name = "build_image_data",
srcs = ["data/build_image_data.py"],
deps = [
"@tf//tensorflow:tensorflow_py",
],
)
sh_binary(
name = "download_and_preprocess_flowers",
srcs = ["data/download_and_preprocess_flowers.sh"],
data = [
":build_image_data",
],
)
sh_binary(
name = "download_and_preprocess_imagenet",
srcs = ["data/download_and_preprocess_imagenet.sh"],
data = [
"data/download_imagenet.sh",
"data/imagenet_2012_validation_synset_labels.txt",
"data/imagenet_lsvrc_2015_synsets.txt",
"data/imagenet_metadata.txt",
"data/preprocess_imagenet_validation_data.py",
"data/process_bounding_boxes.py",
":build_imagenet_data",
],
)
py_binary(
name = "build_imagenet_data",
srcs = ["data/build_imagenet_data.py"],
deps = [
"@tf//tensorflow:tensorflow_py",
],
)
filegroup(
name = "srcs",
srcs = glob(
[
"**/*.py",
"BUILD",
],
),
)
# Inception in TensorFlow
[TOC]
[ImageNet](http://www.image-net.org/) is a common academic data set in machine
learning for training an image recognition system. Code in this directory
demonstrates how to use TensorFlow to train and evaluate
a type of convolutional neural network (CNN) on this academic data set.
In particular, we demonstrate how to train the Inception v3 architecture
as specified in:
_Rethinking the Inception Architecture for Computer Vision_
Christian Szegedy, Vincent Vanhoucke, Sergey Ioffe, Jonathon Shlens,
Zbigniew Wojna
http://arxiv.org/abs/1512.00567
This network achieves 21.2% top-1 and 5.6% top-5 error for single frame
evaluation with a computational cost of 5 billion multiply-adds per inference
and using fewer than 25 million parameters. Below is a visualization
of the model architecture.
of the model architecture.
<center>
![Inception-v3 Architecture](g3doc/inception_v3_architecture.png)
</center>
## Description of Code
The code base provides three core binaries for:
* Training an Inception v3 network from scratch across multiple GPUs and/or
multiple machines using the ImageNet 2012 Challenge training data set.
* Evaluating an Inception v3 network using the ImageNet 2012 Challenge
validation data set.
* Retraining an Inception v3 network on a novel task and back-propagating the
errors to fine-tune the network weights.
The training procedure employs synchronous stochastic gradient descent across
multiple GPUs. The user may specify the number of GPUs they wish to harness.
The synchronous training performs *batch-splitting* by dividing a given batch
across multiple GPUs.
The training setup is nearly identical to the section [Training a Model
Using Multiple GPU Cards](https://www.tensorflow.org/tutorials/deep_cnn/index.html#training-a-model-using-multiple-gpu-cards)
where we have substituted the CIFAR-10 model architecture
with Inception v3. The primary differences with that setup are:
* Calculate and update the batch-norm statistics during training so that they
may be substituted in during evaluation.
* Specify the model architecture using a (still experimental) higher-level
language called TensorFlow-Slim.
For more details about TensorFlow-Slim, please see the
[Slim README](slim/README.md). Please
note that this higher-level language is still *experimental* and the API may
change over time depending on usage and subsequent research.
## Getting Started
**NOTE** Before doing anything, we first need to build TensorFlow from source.
Please follow the instructions at
[Installing From Source](https://www.tensorflow.org/versions/r0.7/get_started/os_setup.html#installing-from-sources).
Before you run the training script for the first time, you will need to
download and convert the ImageNet data to native TFRecord format. The TFRecord
format consists of a set of sharded files where each entry is a serialized
`tf.Example` proto. Each `tf.Example` proto contains the ImageNet image (JPEG
encoded) as well as metadata such as label and bounding box information. See
[`parse_example_proto`](image_processing.py) for details.
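For reference, a minimal sketch of what parsing one of these records looks
like is shown below. This is illustrative only: the exact
`tf.parse_single_example` signature changed across early TensorFlow releases,
and we show just two of the proto's fields.

```python
import tensorflow as tf

def parse_example_sketch(serialized_example):
  # Hypothetical sketch: feature keys match the proto fields described above.
  features = tf.parse_single_example(
      serialized_example,
      features={
          'image/encoded': tf.FixedLenFeature([], tf.string),
          'image/class/label': tf.FixedLenFeature([], tf.int64),
      })
  # Decode the JPEG bytes into an RGB image tensor.
  image = tf.image.decode_jpeg(features['image/encoded'], channels=3)
  label = tf.cast(features['image/class/label'], tf.int32)
  return image, label
```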
We provide a single
[script](data/download_and_preprocess_imagenet.sh)
for downloading and converting ImageNet data to TFRecord format. Downloading
and preprocessing the data may take several hours (up to half a day) depending
on your network and computer speed. Please be patient.
To begin, you will need to sign up for an account with
[ImageNet](http://image-net.org) to gain access to the data. Look for the
sign up page, create an account and request an access key to download the data.
After you have `USERNAME` and `PASSWORD`, you are ready to run our script.
Make sure that your hard disk has at least 500 GB of free space for downloading
and storing the data. Here we select `DATA_DIR=$HOME/imagenet-data` as such a
location but feel free to edit accordingly.
When you run the below script, please enter *USERNAME* and *PASSWORD*
when prompted. This will occur at the very beginning. Once these values are
entered, you will not need to interact with the script again.
```shell
# location of where to place the ImageNet data
DATA_DIR=$HOME/imagenet-data
# build the preprocessing script.
bazel build -c opt inception/download_and_preprocess_imagenet
# run it
bazel-bin/inception/download_and_preprocess_imagenet "${DATA_DIR}"
```
The final line of the script's output should read:
```shell
2016-02-17 14:30:17.287989: Finished writing all 1281167 images in data set.
```
When the script finishes, you will find 1024 training files and 128 validation
files in the `DATA_DIR`. The files will match the patterns
`train-?????-of-01024` and `validation-?????-of-00128`, respectively.
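As a quick sanity check, you can count the generated shards (a simple sketch;
adjust `data_dir` to your `DATA_DIR`):

```python
import glob
import os

data_dir = os.path.expanduser('~/imagenet-data')
# Expect 1024 training shards and 128 validation shards.
print(len(glob.glob(os.path.join(data_dir, 'train-*'))))
print(len(glob.glob(os.path.join(data_dir, 'validation-*'))))
```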
[Congratulations!](https://www.youtube.com/watch?v=9bZkp7q19f0)
You are now ready to train or evaluate with the ImageNet data set.
## How to Train from Scratch
**WARNING** Training an Inception v3 network from scratch is a computationally
intensive task and depending on your compute setup may take several days or
even weeks.
*Before proceeding* please read the [Convolutional Neural
Networks](https://www.tensorflow.org/tutorials/deep_cnn/index.html) tutorial;
in particular, focus on
[Training a Model Using Multiple GPU Cards](https://www.tensorflow.org/tutorials/deep_cnn/index.html#training-a-model-using-multiple-gpu-cards).
The model training method is nearly identical to that described in the
CIFAR-10 multi-GPU model training. Briefly, the model training:
* Places an individual model replica on each GPU and splits the batch
across the GPUs.
* Updates model parameters synchronously by waiting for all GPUs to finish
processing a batch of data.
The training procedure is encapsulated by this diagram of how operations and
variables are placed on the CPU and GPUs, respectively.
<div style="width:40%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="https://www.tensorflow.org/images/Parallelism.png">
</div>
Each tower computes the gradients for a portion of the batch and the gradients
are combined and averaged across the multiple towers in order to provide a
single update of the Variables stored on the CPU.
A crucial aspect of training a network of this size is *training speed* in terms
of wall-clock time. The training speed is dictated by many factors -- most
importantly the batch size and the learning rate schedule. Both of these
parameters are heavily coupled to the hardware setup.
Generally speaking, the batch size is a difficult parameter to tune as it
requires balancing the memory demands of the model, the memory available on
the GPU and the speed of computation. Employing larger batch sizes leads to
more efficient computation and potentially more efficient training steps.
We have tested several hardware setups for training this model from scratch,
but we emphasize that depending on your hardware setup, you may need to adapt
the batch size and learning rate schedule.
Please see the comments in `inception_train.py` for a few selected learning rate
plans based on some selected hardware setups.
To train this model, simply run the following:
```shell
# Build the training binary to run on a GPU. If you do not have a GPU,
# then exclude '--config=cuda'
bazel build -c opt --config=cuda inception/imagenet_train
# run it
bazel-bin/inception/imagenet_train --num_gpus=1 --batch_size=32 --train_dir=/tmp/imagenet_train --data_dir=/tmp/imagenet_data
```
The model reads in the ImageNet training data from `--data_dir`. If you followed
the instructions in [Getting Started](#getting-started), then set
`--data_dir="${DATA_DIR}"`. The script assumes that there exists a set of
sharded TFRecord files containing the ImageNet data. If you have not created
TFRecord files, please refer to [Getting Started](#getting-started).
Here is the output of the above command line when running on a Tesla K40c:
```shell
2016-03-07 12:24:59.922898: step 0, loss = 13.11 (5.3 examples/sec; 6.064 sec/batch)
2016-03-07 12:25:55.206783: step 10, loss = 13.71 (9.4 examples/sec; 3.394 sec/batch)
2016-03-07 12:26:28.905231: step 20, loss = 14.81 (9.5 examples/sec; 3.380 sec/batch)
2016-03-07 12:27:02.699719: step 30, loss = 14.45 (9.5 examples/sec; 3.378 sec/batch)
2016-03-07 12:27:36.515699: step 40, loss = 13.98 (9.5 examples/sec; 3.376 sec/batch)
2016-03-07 12:28:10.220956: step 50, loss = 13.92 (9.6 examples/sec; 3.327 sec/batch)
2016-03-07 12:28:43.658223: step 60, loss = 13.28 (9.6 examples/sec; 3.350 sec/batch)
...
```
This example highlights several important points:
* A log entry is printed every 10 steps and includes the total loss (which
starts around 13.0-14.0) and the speed of processing in terms of throughput
(examples/sec) and batch speed (sec/batch).
* The first step of training is always slow. The primary reason is that the
preprocessing queue must first fill with several thousand example images in
order to reach its minimum capacity before training starts.
The number of GPU devices is specified by `--num_gpus` (which defaults to 1).
Specifying `--num_gpus` greater than 1 splits the batch evenly across the
GPU cards.
```shell
# Build the training binary to run on a GPU. If you do not have a GPU,
# then exclude '--config=cuda'
bazel build -c opt --config=cuda inception/imagenet_train
# run it
bazel-bin/inception/imagenet_train --num_gpus=2 --batch_size=64 --train_dir=/tmp/imagenet_train
```
This model splits the batch of 64 images across 2 GPUs and calculates
the average gradient by waiting for both GPUs to finish calculating the
gradients from their respective data (See diagram above). Generally speaking,
using larger numbers of GPUs leads to higher throughput as well as the
opportunity to use larger batch sizes. In turn, larger batch sizes imply
better estimates of the gradient enabling the usage of higher learning rates.
In summary, using more GPUs simply results in faster training.
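The gradient averaging at the heart of this scheme looks roughly like the
sketch below, modeled on the CIFAR-10 multi-GPU tutorial (written against the
0.7-era API in which `tf.concat` takes the axis as its first argument):

```python
import tensorflow as tf

def average_gradients(tower_grads):
  """Average gradients across towers.

  Args:
    tower_grads: list over towers, each a list of (gradient, variable) pairs
      as returned by Optimizer.compute_gradients().
  Returns:
    List of (gradient, variable) pairs with gradients averaged across towers.
  """
  average_grads = []
  for grad_and_vars in zip(*tower_grads):
    # grad_and_vars is ((grad_gpu0, var), ..., (grad_gpuN, var)).
    grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars]
    grad = tf.reduce_mean(tf.concat(0, grads), 0)
    # Variables are shared across towers; the first tower's pointer suffices.
    average_grads.append((grad, grad_and_vars[0][1]))
  return average_grads
```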
Note that there is considerable noise in the loss function on individual steps
in the previous log. Because of this noise, it is difficult to discern how well
a model is learning. The solution is to launch TensorBoard pointing to the
directory containing the events log.
```shell
tensorboard --logdir=/tmp/imagenet_train
```
TensorBoard has access to the many Summaries produced by the model that
describe multitudes of statistics tracking the model behavior and the quality
of the learned model. In particular, TensorBoard tracks an exponentially smoothed
version of the loss. In practice, it is far easier to judge how well a model
learns by monitoring the smoothed version of the loss.
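For reference, summaries of this kind are written with ops along the lines of
the following sketch, using the 0.7-era summary API; `total_loss` below is a
hypothetical stand-in for the training script's actual loss tensor:

```python
import tensorflow as tf

total_loss = tf.constant(13.0)  # hypothetical stand-in for the real loss
tf.scalar_summary('total_loss', total_loss)
summary_op = tf.merge_all_summaries()
# The writer emits the events files that TensorBoard reads from --logdir.
summary_writer = tf.train.SummaryWriter('/tmp/imagenet_train')
```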
## How to Evaluate
Evaluating an Inception v3 model on the ImageNet 2012 validation data set
requires running a separate binary.
The evaluation procedure is nearly identical to
[Evaluating a Model](https://www.tensorflow.org/tutorials/deep_cnn/index.html#evaluating-a-model)
described in the [Convolutional Neural Networks](https://www.tensorflow.org/tutorials/deep_cnn/index.html)
tutorial.
**WARNING** Be careful not to run the evaluation and training binaries on the
same GPU, or else you might run out of memory. Consider running the evaluation
on a separate GPU if available, or suspending the training binary while running
the evaluation on the same GPU.
Briefly, one can evaluate the model by running:
```shell
# Build the evaluation binary to run on a GPU. If you do not have a GPU,
# then exclude '--config=cuda'
bazel build -c opt --config=cuda inception/imagenet_eval
# run it
bazel-bin/inception/imagenet_eval --checkpoint_dir=/tmp/imagenet_train --eval_dir=/tmp/imagenet_eval
```
Note that we point `--checkpoint_dir` to the location of the checkpoints
saved by `inception_train.py` above. Running the above command results in the
following output:
```shell
2016-02-17 22:32:50.391206: precision @ 1 = 0.735
...
```
The script calculates the precision @ 1 over the entire validation data
periodically. The precision @ 1 measures how often the highest scoring
prediction from the model matches the ImageNet label -- in this case, 73.5%.
If you wish to run the eval just once and not periodically, append the
`--run_once` option.
Much like the training script, `imagenet_eval.py` also
exports summaries that may be visualized in TensorBoard. These summaries
calculate additional statistics on the predictions (e.g. recall @ 5) as well
as monitor the statistics of the model activations and weights during
evaluation.
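Under the hood, metrics like these reduce to counting `tf.nn.in_top_k` hits
over the validation set. A minimal sketch (the placeholder shapes are
illustrative; 1001 assumes 1000 labels plus a background class):

```python
import tensorflow as tf

logits = tf.placeholder(tf.float32, [None, 1001])  # model scores
labels = tf.placeholder(tf.int32, [None])          # ground-truth labels
top_1_op = tf.nn.in_top_k(logits, labels, 1)
top_5_op = tf.nn.in_top_k(logits, labels, 5)
# Accumulate the True entries of top_1_op over all validation batches and
# divide by the number of examples to obtain precision @ 1.
```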
## How to Fine-Tune a Pre-Trained Model on a New Task
### Getting Started
Much like training the ImageNet model, we must first convert a new data set to
the sharded TFRecord format in which each entry is a serialized `tf.Example`
proto. We have provided a script demonstrating how to do this for a small data
set of a few thousand flower images spread across 5 labels:
```shell
daisy, dandelion, roses, sunflowers, tulips
```
There is a single automated script that downloads the data set and converts
it to the TFRecord format. Much like the ImageNet data set, each record in the
TFRecord format is a serialized `tf.Example` proto whose entries include
a JPEG-encoded string and an integer label. Please see
[`parse_example_proto`](image_processing.py) for details.
The script takes just a few minutes to run, depending on your network
connection speed for downloading and processing the images. Your hard disk
requires 200MB of free storage. Here we select
`FLOWERS_DATA_DIR=$HOME/flowers-data` as such a location but feel free to edit
accordingly.
```shell
# location of where to place the flowers data
FLOWERS_DATA_DIR=$HOME/flowers-data
# build the preprocessing script.
bazel build -c opt inception/download_and_preprocess_flowers
# run it
bazel-bin/inception/download_and_preprocess_flowers "${FLOWERS_DATA_DIR}"
```
If the script runs successfully, the final line of the terminal output should
look like:
```shell
2016-02-24 20:42:25.067551: Finished writing all 3170 images in data set.
```
When the script finishes, you will find 2 shards each for the training and
validation files in the `FLOWERS_DATA_DIR`. The files will match the patterns
`train-?????-of-00002` and `validation-?????-of-00002`, respectively.
**NOTE** If you wish to prepare a custom image data set for transfer learning,
you will need to invoke [`build_image_data.py`](data/build_image_data.py)
on your custom data set.
Please see the associated options and assumptions behind this script by reading
the comments section of [`build_image_data.py`](data/build_image_data.py).
The second piece you will need is a trained Inception v3 image model. You have
the option of either training one yourself (See
[How to Train from Scratch](#how-to-train-from-scratch) for details) or you can
download a pre-trained model like so:
```shell
# location of where to place the Inception v3 model
INCEPTION_MODEL_DIR=$HOME/inception-v3-model
cd ${INCEPTION_MODEL_DIR}
# download the Inception v3 model
curl -O http://download.tensorflow.org/models/image/imagenet/inception-v3-2016-03-01.tar.gz
tar xzf inception-v3-2016-03-01.tar.gz
# this will create a directory called inception-v3 which contains the following files.
> ls inception-v3
README.txt
checkpoint
model.ckpt-157585
```
[Congratulations!](https://www.youtube.com/watch?v=9bZkp7q19f0)
You are now ready to fine-tune your pre-trained Inception v3 model
with the flower data set.
### How to Retrain a Trained Model on the Flowers Data
We are now ready to fine-tune a pre-trained Inception-v3 model on
the flowers data set. This requires two distinct changes to our training
procedure:
1. Build the exact same model as before, except change the number
of labels in the final classification layer.
2. Restore all weights from the pre-trained Inception-v3 except for the
final classification layer; this layer is instead randomly initialized.
We can perform these two operations by specifying two flags:
`--pretrained_model_checkpoint_path` and `--fine_tune`.
The first flag is a string that points to the path of a pre-trained Inception-v3
model. If this flag is specified, it will load the entire model from the
checkpoint before the script begins training.
The second flag `--fine_tune` is a boolean that indicates whether the last
classification layer should be randomly initialized or restored.
You may set this flag to false
if you wish to continue training a pre-trained model from a checkpoint. If you
set this flag to true, you can train a new classification layer from scratch.
In order to understand how `--fine_tune` works, please see the discussion
on `Variables` in the TensorFlow-Slim [`README.md`](slim/README.md).
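Conceptually, `--fine_tune=True` amounts to restoring every variable except
those of the final classification layer, along these lines. This is only a
sketch: the actual script relies on TensorFlow-Slim's variable collections,
and the scope name `'logits'` below is purely illustrative:

```python
import tensorflow as tf

# Restore all variables except the (hypothetically named) final layer,
# which keeps its fresh random initialization.
variables_to_restore = [v for v in tf.all_variables()
                        if not v.name.startswith('logits')]
restorer = tf.train.Saver(variables_to_restore)

with tf.Session() as sess:
  sess.run(tf.initialize_all_variables())      # random init for everything...
  restorer.restore(sess, 'model.ckpt-157585')  # ...then overwrite all but 'logits'
```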
Putting this all together you can retrain a pre-trained Inception-v3 model
on the flowers data set with the following command.
```shell
# Build the training binary to run on a GPU. If you do not have a GPU,
# then exclude '--config=cuda'
bazel build -c opt --config=cuda inception/flowers_train
# Path to the downloaded Inception-v3 model.
MODEL_PATH="${INCEPTION_MODEL_DIR}/model.ckpt-157585"
# Directory where the flowers data resides.
FLOWERS_DATA_DIR=/tmp/flowers-data/
# Directory where to save the checkpoint and events files.
TRAIN_DIR=/tmp/flowers_train/
# Run the fine-tuning on the flowers data set starting from the pre-trained
# Inception-v3 model.
bazel-bin/inception/flowers_train \
--train_dir="${TRAIN_DIR}" \
--data_dir="${FLOWERS_DATA_DIR}" \
--pretrained_model_checkpoint_path="${MODEL_PATH}" \
--fine_tune=True \
--initial_learning_rate=0.001 \
--input_queue_memory_factor=1
```
We have added a few extra options to the training procedure.
* Fine-tuning a model on a separate data set requires significantly lowering
the initial learning rate. We set the initial learning rate to 0.001.
* The flowers data set is quite small so we shrink the size of the shuffling
queue of examples. See [Adjusting Memory Demands](#adjusting-memory-demands)
for more details.
The training script only reports the loss. To evaluate the quality of the
fine-tuned model, you will need to run `flowers_eval`:
```shell
# Build the evaluation binary to run on a GPU. If you do not have a GPU,
# then exclude '--config=cuda'
bazel build -c opt --config=cuda inception/flowers_eval
# Directory where we saved the fine-tuned checkpoint and events files.
TRAIN_DIR=/tmp/flowers_train/
# Directory where the flowers data resides.
FLOWERS_DATA_DIR=/tmp/flowers-data/
# Directory where to save the evaluation events files.
EVAL_DIR=/tmp/flowers_eval/
# Evaluate the fine-tuned model on a hold-out of the flower data set.
bazel-bin/inception/flowers_eval \
--eval_dir="${EVAL_DIR}" \
--data_dir="${FLOWERS_DATA_DIR}" \
--subset=validation \
--num_examples=500 \
--checkpoint_dir="${TRAIN_DIR}" \
--input_queue_memory_factor=1 \
--run_once
```
We find that the evaluation arrives at roughly 93.4% precision@1 after the
model has been running for 2000 steps.
```shell
Successfully loaded model from /tmp/flowers/model.ckpt-1999 at step=1999.
2016-03-01 16:52:51.761219: starting evaluation on (validation).
2016-03-01 16:53:05.450419: [20 batches out of 20] (36.5 examples/sec; 0.684sec/batch)
2016-03-01 16:53:05.450471: precision @ 1 = 0.9340 recall @ 5 = 0.9960 [500 examples]
```
## How to Construct a New Dataset for Retraining
One can use the existing scripts supplied with this model to build a new
data set for training or fine-tuning. The main script to employ is
[`build_image_data.py`](data/build_image_data.py). Briefly, this script takes a
structured directory of images and converts it to a sharded `TFRecord` that can
be read by the Inception model.
In particular, you will need to create a directory of training images
that reside within `$TRAIN_DIR` and `$VALIDATION_DIR` arranged as such:
```shell
$TRAIN_DIR/dog/image0.jpeg
$TRAIN_DIR/dog/image1.jpg
$TRAIN_DIR/dog/image2.png
...
$TRAIN_DIR/cat/weird-image.jpeg
$TRAIN_DIR/cat/my-image.jpeg
$TRAIN_DIR/cat/my-image.JPG
...
$VALIDATION_DIR/dog/imageA.jpeg
$VALIDATION_DIR/dog/imageB.jpg
$VALIDATION_DIR/dog/imageC.png
...
$VALIDATION_DIR/cat/weird-image.PNG
$VALIDATION_DIR/cat/that-image.jpg
$VALIDATION_DIR/cat/cat.JPG
...
```
Each sub-directory in `$TRAIN_DIR` and `$VALIDATION_DIR` corresponds to a
unique label for the images that reside within that sub-directory. The images
may be JPEG or PNG images. We do not currently support other image types.
Once the data is arranged in this directory structure, we can run
`build_image_data.py` on the data to generate the sharded `TFRecord` dataset.
Each entry of the `TFRecord` is a serialized `tf.Example` protocol buffer.
A complete list of information contained in the `tf.Example` is described
in the comments of `build_image_data.py`.
To run `build_image_data.py`, you can use the following command line:
```shell
# location to where to save the TFRecord data.
OUTPUT_DIRECTORY=$HOME/my-custom-data/
# build the preprocessing script.
bazel build -c opt inception/build_image_data
# convert the data.
bazel-bin/inception/build_image_data \
--train_directory="${TRAIN_DIR}" \
--validation_directory="${VALIDATION_DIR}" \
--output_directory="${OUTPUT_DIRECTORY}" \
--labels_file="${LABELS_FILE}" \
--train_shards=128 \
--validation_shards=24 \
--num_threads=8
```
where `$OUTPUT_DIRECTORY` is the location of the sharded `TFRecords`. The
`$LABELS_FILE` is a text file read by the script that provides a list of all of
the labels. For instance, in the case of the flowers data set, the
`$LABELS_FILE` contained the following data:
```shell
daisy
dandelion
roses
sunflowers
tulips
```
Note that each row of the labels file corresponds to an entry in the final
classifier in the model. That is, `daisy` corresponds to classifier entry `1`;
`dandelion` to entry `2`, etc. We skip label `0` as a background class.
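In code, that mapping is simply the following (a sketch mirroring
`_find_image_files` in `build_image_data.py`; the path is hypothetical):

```python
# Map each line of the labels file to its classifier entry, reserving 0
# for the background class.
with open('labels.txt') as f:
  unique_labels = [l.strip() for l in f]
label_index = {text: i + 1 for i, text in enumerate(unique_labels)}
# {'daisy': 1, 'dandelion': 2, 'roses': 3, 'sunflowers': 4, 'tulips': 5}
```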
Running this script produces files in `$OUTPUT_DIRECTORY` that look like the
following:
```shell
$OUTPUT_DIRECTORY/train-00000-of-00128
$OUTPUT_DIRECTORY/train-00001-of-00128
...
$OUTPUT_DIRECTORY/train-00127-of-00128
and
$OUTPUT_DIRECTORY/validation-00000-of-00024
$OUTPUT_DIRECTORY/validation-00001-of-00024
...
$OUTPUT_DIRECTORY/validation-00023-of-00024
```
where 128 and 24 are the number of shards specified for each data set,
respectively. Generally speaking, we aim to select the number of shards such
that roughly 1024 images reside in each shard. Once this data set is built,
you are ready to train or fine-tune an Inception model on this data set.
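As a rough guide for choosing shard counts (a sketch; remember that
`--num_threads` must evenly divide both shard counts):

```python
# Pick the largest multiple of num_threads that keeps shards at
# roughly 1024 images each.
num_threads = 8
num_images = 120000  # hypothetical data set size
num_shards = max(num_threads,
                 num_images // 1024 // num_threads * num_threads)
print(num_shards)  # 112 shards of ~1071 images each
```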
## Practical Considerations for Training a Model
The model architecture and training procedure are heavily dependent on the
hardware used to train the model. If you wish to train or fine-tune this
model on your machine **you will need to adjust and empirically determine
a good set of training hyper-parameters for your setup**. What follows are
some general considerations for novices.
### Finding Good Hyperparameters
Roughly 5-10 hyper-parameters govern the speed at which a network is trained.
In addition to `--batch_size` and `--num_gpus`, there are several constants
defined in [inception_train.py](./inception_train.py) which dictate the
learning schedule.
```shell
RMSPROP_DECAY = 0.9 # Decay term for RMSProp.
MOMENTUM = 0.9 # Momentum in RMSProp.
RMSPROP_EPSILON = 1.0 # Epsilon term for RMSProp.
INITIAL_LEARNING_RATE = 0.1 # Initial learning rate.
NUM_EPOCHS_PER_DECAY = 30.0 # Epochs after which learning rate decays.
LEARNING_RATE_DECAY_FACTOR = 0.16 # Learning rate decay factor.
```
There are many papers that discuss the various tricks and trade-offs associated
with training a model with stochastic gradient descent. For those new to the
field, some great references are:
* Y Bengio, [Practical recommendations for gradient-based training of deep architectures](http://arxiv.org/abs/1206.5533)
* I Goodfellow, Y Bengio and A Courville, [Deep Learning](http://www.deeplearningbook.org/)
What follows is a summary of some general advice for identifying appropriate
model hyper-parameters in the context of this particular
model training setup. Namely,
this library provides *synchronous* updates to model parameters based on
batch-splitting the model across multiple GPUs.
* Higher learning rates lead to faster training. Too high a learning rate
leads to instability and will cause model parameters to diverge to infinity
or NaN.
* Larger batch sizes lead to higher quality estimates of the gradient and
permit training the model with higher learning rates.
* Often the GPU memory is a bottleneck that prevents employing larger batch
sizes. Employing more GPUs allows one to use larger batch sizes because
this model splits the batch across the GPUs.
**NOTE** If one wishes to train this model with *asynchronous* gradient updates,
one will need to substantially alter this model and new considerations need to
be factored into hyperparameter tuning.
See [Large Scale Distributed Deep Networks](http://research.google.com/archive/large_deep_networks_nips2012.html)
for a discussion in this domain.
### Adjusting Memory Demands
Training this model has large memory demands in terms of the CPU and GPU. Let's
discuss each item in turn.
GPU memory is relatively small compared to CPU memory. Two items dictate the
amount of GPU memory employed -- the model architecture and the batch size.
Assuming that you keep the model architecture fixed, the sole parameter
governing the GPU memory demand is the batch size. A good rule of thumb is to
employ as large a batch size as will fit on the GPU.
If you run out of GPU memory, either lower the `--batch_size` or employ more
GPUs on your desktop. The model performs batch-splitting across GPUs, thus N
GPUs can handle N times the batch size of 1 GPU.
The model requires a large amount of CPU memory as well. We have tuned the
model to employ about 40GB of CPU memory. Thus, having access to 64 or 128 GB
of CPU memory would be ideal.
If that is not possible, you can tune down the memory demands of the model
by lowering `--input_queue_memory_factor`. Images are preprocessed
asynchronously with respect to the main training across
`--num_preprocess_threads` threads. The preprocessed images are stored in a
shuffling queue from which each GPU performs a dequeue operation in order
to receive a `batch_size` worth of images.
In order to guarantee good shuffling across the data, we maintain a large
shuffling queue of 1024 x `input_queue_memory_factor` images. For the current
model architecture, this corresponds to about 16GB of CPU memory. You may lower
`input_queue_memory_factor` in order to decrease the memory footprint. Keep
in mind though that lowering this value drastically may result in a model
with slightly lower predictive accuracy when training from scratch. Please
see comments in [`image_processing.py`](./image_processing.py) for more details.
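The 16GB figure follows directly from the queue size and the per-image
footprint, assuming the default `input_queue_memory_factor` of 16 and float32
storage of `299x299x3` images:

```python
# Back-of-envelope estimate of the shuffling queue's memory footprint.
images_in_queue = 1024 * 16          # 1024 x input_queue_memory_factor
bytes_per_image = 299 * 299 * 3 * 4  # 299x299 RGB, 4 bytes per float32
print('%.1f GB' % (images_in_queue * bytes_per_image / float(1024 ** 3)))
# ~16.4 GB
```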
## Troubleshooting
#### The model runs out of CPU memory.
In lieu of buying more CPU memory, an easy fix is to
decrease `--input_queue_memory_factor`. See
[Adjusting Memory Demands](#adjusting-memory-demands).
#### The model runs out of GPU memory.
The data is not able to fit on the GPU card. The simplest solution is to
decrease the batch size of the model. Otherwise, you will need to think about
a more sophisticated method for specifying the training that cuts up the model
across multiple `session.run()` calls or partitions the model across multiple
GPUs. See [Using GPUs](https://www.tensorflow.org/versions/r0.7/how_tos/using_gpu/index.html)
and
[Adjusting Memory Demands](#adjusting-memory-demands)
for more information.
#### The model training results in NaNs.
The learning rate of the model is too high. Turn down your learning rate.
#### I wish to train a model with a different image size.
The simplest solution is to artificially resize your images to `299x299`
pixels. See the
[Images](https://www.tensorflow.org/versions/r0.7/api_docs/python/image.html)
section for many resizing, cropping and padding methods.
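For example, a resize might look like the following sketch (in the 0.7-era API
`tf.image.resize_images` takes the new height and width as separate arguments;
later versions take a single `size` list):

```python
import tensorflow as tf

jpeg_data = tf.placeholder(tf.string)
image = tf.image.decode_jpeg(jpeg_data, channels=3)
resized = tf.image.resize_images(image, 299, 299)
```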
Note that the entire model architecture is predicated on a `299x299` image,
thus if you wish to change the input image size, then you may need to redesign
the entire model architecture.
#### What hardware specification are these hyper-parameters targeted for?
We targeted a desktop with 128GB of CPU RAM connected to 8 NVIDIA Tesla K40
GPU cards, but we have also run this on desktops with 32GB of CPU RAM and
1 NVIDIA Tesla K40. You can get a sense of the various training configurations
we tested by reading the comments in
[`inception_train.py`](./inception_train.py).
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converts image data to TFRecords file format with Example protos.
The image data set is expected to reside in JPEG files located in the
following directory structure.
data_dir/label_0/image0.jpeg
data_dir/label_0/image1.jpg
...
data_dir/label_1/weird-image.jpeg
data_dir/label_1/my-image.jpeg
...
where the sub-directory is the unique label associated with these images.
This TensorFlow script converts the training and evaluation data into
a sharded data set consisting of TFRecord files
train_directory/train-00000-of-01024
train_directory/train-00001-of-01024
...
train_directory/train-00127-of-01024
and
validation_directory/validation-00000-of-00128
validation_directory/validation-00001-of-00128
...
validation_directory/validation-00127-of-00128
where we have selected 1024 and 128 shards for each data set. Each record
within the TFRecord file is a serialized Example proto. The Example proto
contains the following fields:
image/encoded: string containing JPEG encoded image in RGB colorspace
image/height: integer, image height in pixels
image/width: integer, image width in pixels
image/colorspace: string, specifying the colorspace, always 'RGB'
image/channels: integer, specifying the number of channels, always 3
image/format: string, specifying the format, always 'JPEG'
image/filename: string containing the basename of the image file
e.g. 'n01440764_10026.JPEG' or 'ILSVRC2012_val_00000293.JPEG'
image/class/label: integer specifying the index in a classification layer.
The label ranges from [0, num_labels] where 0 is unused and left as
the background class.
image/class/text: string specifying the human-readable version of the label
e.g. 'dog'
If your data set involves bounding boxes, please look at build_imagenet_data.py.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from datetime import datetime
import os
import random
import sys
import threading
import numpy as np
import tensorflow as tf
tf.app.flags.DEFINE_string('train_directory', '/tmp/',
'Training data directory')
tf.app.flags.DEFINE_string('validation_directory', '/tmp/',
'Validation data directory')
tf.app.flags.DEFINE_string('output_directory', '/tmp/',
'Output data directory')
tf.app.flags.DEFINE_integer('train_shards', 2,
'Number of shards in training TFRecord files.')
tf.app.flags.DEFINE_integer('validation_shards', 2,
'Number of shards in validation TFRecord files.')
tf.app.flags.DEFINE_integer('num_threads', 2,
'Number of threads to preprocess the images.')
# The labels file contains a list of valid labels.
# Assumes that the file contains entries as such:
# dog
# cat
# flower
# where each line corresponds to a label. We map each label contained in
# the file to an integer corresponding to the line number starting from 1
# (label 0 is reserved as a background class).
tf.app.flags.DEFINE_string('labels_file', '', 'Labels file')
FLAGS = tf.app.flags.FLAGS
def _int64_feature(value):
"""Wrapper for inserting int64 features into Example proto."""
if not isinstance(value, list):
value = [value]
return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
def _bytes_feature(value):
"""Wrapper for inserting bytes features into Example proto."""
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _convert_to_example(filename, image_buffer, label, text, height, width):
"""Build an Example proto for an example.
Args:
filename: string, path to an image file, e.g., '/path/to/example.JPG'
image_buffer: string, JPEG encoding of RGB image
label: integer, identifier for the ground truth for the network
text: string, unique human-readable, e.g. 'dog'
height: integer, image height in pixels
width: integer, image width in pixels
Returns:
Example proto
"""
colorspace = 'RGB'
channels = 3
image_format = 'JPEG'
example = tf.train.Example(features=tf.train.Features(feature={
'image/height': _int64_feature(height),
'image/width': _int64_feature(width),
'image/colorspace': _bytes_feature(colorspace),
'image/channels': _int64_feature(channels),
'image/class/label': _int64_feature(label),
'image/class/text': _bytes_feature(text),
'image/format': _bytes_feature(image_format),
'image/filename': _bytes_feature(os.path.basename(filename)),
'image/encoded': _bytes_feature(image_buffer)}))
return example
class ImageCoder(object):
"""Helper class that provides TensorFlow image coding utilities."""
def __init__(self):
# Create a single Session to run all image coding calls.
self._sess = tf.Session()
# Initializes function that converts PNG to JPEG data.
self._png_data = tf.placeholder(dtype=tf.string)
image = tf.image.decode_png(self._png_data, channels=3)
self._png_to_jpeg = tf.image.encode_jpeg(image, format='rgb', quality=100)
# Initializes function that decodes RGB JPEG data.
self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)
def png_to_jpeg(self, image_data):
return self._sess.run(self._png_to_jpeg,
feed_dict={self._png_data: image_data})
def decode_jpeg(self, image_data):
image = self._sess.run(self._decode_jpeg,
feed_dict={self._decode_jpeg_data: image_data})
assert len(image.shape) == 3
assert image.shape[2] == 3
return image
def _is_png(filename):
"""Determine if a file contains a PNG format image.
Args:
filename: string, path of the image file.
Returns:
boolean indicating if the image is a PNG.
"""
# Case-insensitive so that files ending in '.PNG' are caught as well.
return '.png' in filename.lower()
def _process_image(filename, coder):
"""Process a single image file.
Args:
filename: string, path to an image file e.g., '/path/to/example.JPG'.
coder: instance of ImageCoder to provide TensorFlow image coding utils.
Returns:
image_buffer: string, JPEG encoding of RGB image.
height: integer, image height in pixels.
width: integer, image width in pixels.
"""
# Read the image file.
image_data = tf.gfile.FastGFile(filename, 'r').read()
# Convert any PNG to JPEG's for consistency.
if _is_png(filename):
print('Converting PNG to JPEG for %s' % filename)
image_data = coder.png_to_jpeg(image_data)
# Decode the RGB JPEG.
image = coder.decode_jpeg(image_data)
# Check that image converted to RGB
assert len(image.shape) == 3
height = image.shape[0]
width = image.shape[1]
assert image.shape[2] == 3
return image_data, height, width
def _process_image_files_batch(coder, thread_index, ranges, name, filenames,
texts, labels, num_shards):
"""Processes and saves list of images as TFRecord in 1 thread.
Args:
coder: instance of ImageCoder to provide TensorFlow image coding utils.
thread_index: integer, unique thread index within [0, len(ranges)).
ranges: list of pairs of integers specifying the range of images to
process in each parallel batch.
name: string, unique identifier specifying the data set
filenames: list of strings; each string is a path to an image file
texts: list of strings; each string is human readable, e.g. 'dog'
labels: list of integer; each integer identifies the ground truth
num_shards: integer number of shards for this data set.
"""
# Each thread produces N shards where N = int(num_shards / num_threads).
# For instance, if num_shards = 128, and the num_threads = 2, then the first
# thread would produce shards [0, 64).
num_threads = len(ranges)
assert not num_shards % num_threads
num_shards_per_batch = int(num_shards / num_threads)
shard_ranges = np.linspace(ranges[thread_index][0],
ranges[thread_index][1],
num_shards_per_batch + 1).astype(int)
num_files_in_thread = ranges[thread_index][1] - ranges[thread_index][0]
counter = 0
for s in xrange(num_shards_per_batch):
# Generate a sharded version of the file name, e.g. 'train-00002-of-00010'
shard = thread_index * num_shards_per_batch + s
output_filename = '%s-%.5d-of-%.5d' % (name, shard, num_shards)
output_file = os.path.join(FLAGS.output_directory, output_filename)
writer = tf.python_io.TFRecordWriter(output_file)
shard_counter = 0
files_in_shard = np.arange(shard_ranges[s], shard_ranges[s + 1], dtype=int)
for i in files_in_shard:
filename = filenames[i]
label = labels[i]
text = texts[i]
image_buffer, height, width = _process_image(filename, coder)
example = _convert_to_example(filename, image_buffer, label,
text, height, width)
writer.write(example.SerializeToString())
shard_counter += 1
counter += 1
if not counter % 1000:
print('%s [thread %d]: Processed %d of %d images in thread batch.' %
(datetime.now(), thread_index, counter, num_files_in_thread))
sys.stdout.flush()
print('%s [thread %d]: Wrote %d images to %s' %
(datetime.now(), thread_index, shard_counter, output_file))
sys.stdout.flush()
shard_counter = 0
print('%s [thread %d]: Wrote %d images to %d shards.' %
(datetime.now(), thread_index, counter, num_shards_per_batch))
sys.stdout.flush()
def _process_image_files(name, filenames, texts, labels, num_shards):
"""Process and save list of images as TFRecord of Example protos.
Args:
name: string, unique identifier specifying the data set
filenames: list of strings; each string is a path to an image file
texts: list of strings; each string is human readable, e.g. 'dog'
labels: list of integer; each integer identifies the ground truth
num_shards: integer number of shards for this data set.
"""
assert len(filenames) == len(texts)
assert len(filenames) == len(labels)
# Break all images into batches with ranges [ranges[i][0], ranges[i][1]).
spacing = np.linspace(0, len(filenames), FLAGS.num_threads + 1).astype(np.int)
ranges = []
threads = []
for i in xrange(len(spacing) - 1):
ranges.append([spacing[i], spacing[i+1]])
# Launch a thread for each batch.
print('Launching %d threads for spacings: %s' % (FLAGS.num_threads, ranges))
sys.stdout.flush()
# Create a mechanism for monitoring when all threads are finished.
coord = tf.train.Coordinator()
# Create a generic TensorFlow-based utility for converting all image codings.
coder = ImageCoder()
threads = []
for thread_index in xrange(len(ranges)):
args = (coder, thread_index, ranges, name, filenames,
texts, labels, num_shards)
t = threading.Thread(target=_process_image_files_batch, args=args)
t.start()
threads.append(t)
# Wait for all the threads to terminate.
coord.join(threads)
print('%s: Finished writing all %d images in data set.' %
(datetime.now(), len(filenames)))
sys.stdout.flush()
def _find_image_files(data_dir, labels_file):
"""Build a list of all images files and labels in the data set.
Args:
data_dir: string, path to the root directory of images.
Assumes that the image data set resides in JPEG files located in
the following directory structure.
data_dir/dog/another-image.JPEG
data_dir/dog/my-image.jpg
where 'dog' is the label associated with these images.
labels_file: string, path to the labels file.
The list of valid labels are held in this file. Assumes that the file
contains entries as such:
dog
cat
flower
where each line corresponds to a label. We map each label contained in
the file to an integer starting with the integer 1 corresponding to the
label contained in the first line (label 0 is reserved as background).
Returns:
filenames: list of strings; each string is a path to an image file.
texts: list of strings; each string is the class, e.g. 'dog'
labels: list of integer; each integer identifies the ground truth.
"""
print('Determining list of input files and labels from %s.' % data_dir)
unique_labels = [l.strip() for l in tf.gfile.FastGFile(
labels_file, 'r').readlines()]
labels = []
filenames = []
texts = []
# Leave label index 0 empty as a background class.
label_index = 1
# Construct the list of JPEG files and labels.
for text in unique_labels:
jpeg_file_path = '%s/%s/*' % (data_dir, text)
matching_files = tf.gfile.Glob(jpeg_file_path)
labels.extend([label_index] * len(matching_files))
texts.extend([text] * len(matching_files))
filenames.extend(matching_files)
if not label_index % 100:
print('Finished finding files in %d of %d classes.' % (
label_index, len(unique_labels)))
label_index += 1
# Shuffle the ordering of all image files in order to guarantee
# random ordering of the images with respect to label in the
# saved TFRecord files. Make the randomization repeatable.
shuffled_index = range(len(filenames))
random.seed(12345)
random.shuffle(shuffled_index)
filenames = [filenames[i] for i in shuffled_index]
texts = [texts[i] for i in shuffled_index]
labels = [labels[i] for i in shuffled_index]
print('Found %d JPEG files across %d labels inside %s.' %
(len(filenames), len(unique_labels), data_dir))
return filenames, texts, labels
def _process_dataset(name, directory, num_shards, labels_file):
"""Process a complete data set and save it as a TFRecord.
Args:
name: string, unique identifier specifying the data set.
directory: string, root path to the data set.
num_shards: integer number of shards for this data set.
labels_file: string, path to the labels file.
"""
filenames, texts, labels = _find_image_files(directory, labels_file)
_process_image_files(name, filenames, texts, labels, num_shards)
def main(unused_argv):
assert not FLAGS.train_shards % FLAGS.num_threads, (
'Please make the FLAGS.num_threads commensurate with FLAGS.train_shards')
assert not FLAGS.validation_shards % FLAGS.num_threads, (
'Please make the FLAGS.num_threads commensurate with '
'FLAGS.validation_shards')
print('Saving results to %s' % FLAGS.output_directory)
# Run it!
_process_dataset('validation', FLAGS.validation_directory,
FLAGS.validation_shards, FLAGS.labels_file)
_process_dataset('train', FLAGS.train_directory,
FLAGS.train_shards, FLAGS.labels_file)
if __name__ == '__main__':
tf.app.run()
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converts ImageNet data to TFRecords file format with Example protos.
The raw ImageNet data set is expected to reside in JPEG files located in the
following directory structure.
data_dir/n01440764/ILSVRC2012_val_00000293.JPEG
data_dir/n01440764/ILSVRC2012_val_00000543.JPEG
...
where 'n01440764' is the unique synset label associated with
these images.
The training data set consists of 1000 sub-directories (i.e. labels)
each containing 1200 JPEG images for a total of 1.2M JPEG images.
The evaluation data set consists of 1000 sub-directories (i.e. labels)
each containing 50 JPEG images for a total of 50K JPEG images.
This TensorFlow script converts the training and evaluation data into
a sharded data set consisting of 1024 and 128 TFRecord files, respectively.
train_directory/train-00000-of-01024
train_directory/train-00001-of-01024
...
train_directory/train-00127-of-01024
and
validation_directory/validation-00000-of-00128
validation_directory/validation-00001-of-00128
...
validation_directory/validation-00127-of-00128
Each validation TFRecord file contains ~390 records. Each training TFRecord
file contains ~1250 records. Each record within the TFRecord file is a
serialized Example proto. The Example proto contains the following fields:
image/encoded: string containing JPEG encoded image in RGB colorspace
image/height: integer, image height in pixels
image/width: integer, image width in pixels
image/colorspace: string, specifying the colorspace, always 'RGB'
image/channels: integer, specifying the number of channels, always 3
image/format: string, specifying the format, always 'JPEG'
image/filename: string containing the basename of the image file
e.g. 'n01440764_10026.JPEG' or 'ILSVRC2012_val_00000293.JPEG'
image/class/label: integer specifying the index in a classification layer.
The label ranges from [1, 1000] where 0 is not used.
image/class/synset: string specifying the unique ID of the label,
e.g. 'n01440764'
image/class/text: string specifying the human-readable version of the label
e.g. 'red fox, Vulpes vulpes'
image/object/bbox/xmin: list of floats specifying the 0+ human annotated
bounding boxes
image/object/bbox/xmax: list of floats specifying the 0+ human annotated
bounding boxes
image/object/bbox/ymin: list of floats specifying the 0+ human annotated
bounding boxes
image/object/bbox/ymax: list of floats specifying the 0+ human annotated
bounding boxes
image/object/bbox/label: integer specifying the index in a classification
layer. The label ranges from [1, 1000] where 0 is not used. Note this is
always identical to the image label.
Note that the length of xmin is identical to the length of xmax, ymin and ymax
for each example.
Running this script using 16 threads may take ~2.5 hours on an HP Z420.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from datetime import datetime
import os
import random
import sys
import threading
import numpy as np
import tensorflow as tf
tf.app.flags.DEFINE_string('train_directory', '/tmp/',
'Training data directory')
tf.app.flags.DEFINE_string('validation_directory', '/tmp/',
'Validation data directory')
tf.app.flags.DEFINE_string('output_directory', '/tmp/',
'Output data directory')
tf.app.flags.DEFINE_integer('train_shards', 1024,
'Number of shards in training TFRecord files.')
tf.app.flags.DEFINE_integer('validation_shards', 128,
'Number of shards in validation TFRecord files.')
tf.app.flags.DEFINE_integer('num_threads', 8,
'Number of threads to preprocess the images.')
# The labels file contains a list of valid labels.
# Assumes that the file contains entries as such:
# n01440764
# n01443537
# n01484850
# where each line corresponds to a label expressed as a synset. We map
# each synset contained in the file to an integer (based on the alphabetical
# ordering). See below for details.
tf.app.flags.DEFINE_string('labels_file',
'imagenet_lsvrc_2015_synsets.txt',
'Labels file')
# This file contains the mapping from synset to human-readable label.
# Assumes each line of the file looks like:
#
# n02119247 black fox
# n02119359 silver fox
# n02119477 red fox, Vulpes fulva
#
# where each line corresponds to a unique mapping. Note that each line is
# formatted as <synset>\t<human readable label>.
tf.app.flags.DEFINE_string('imagenet_metadata_file',
'imagenet_metadata.txt',
'ImageNet metadata file')
# This file is the output of process_bounding_boxes.py
# Assumes each line of the file looks like:
#
# n00007846_64193.JPEG,0.0060,0.2620,0.7545,0.9940
#
# where each line corresponds to one bounding box annotation associated
# with an image. Each line can be parsed as:
#
# <JPEG file name>, <xmin>, <ymin>, <xmax>, <ymax>
#
# Note that there might exist multiple bounding box annotations associated
# with an image file.
tf.app.flags.DEFINE_string('bounding_box_file',
'./imagenet_2012_bounding_boxes.csv',
'Bounding box file')
FLAGS = tf.app.flags.FLAGS
def _int64_feature(value):
"""Wrapper for inserting int64 features into Example proto."""
if not isinstance(value, list):
value = [value]
return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
def _float_feature(value):
"""Wrapper for inserting float features into Example proto."""
if not isinstance(value, list):
value = [value]
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
def _bytes_feature(value):
"""Wrapper for inserting bytes features into Example proto."""
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _convert_to_example(filename, image_buffer, label, synset, human, bbox,
height, width):
"""Build an Example proto for an example.
Args:
filename: string, path to an image file, e.g., '/path/to/example.JPG'
image_buffer: string, JPEG encoding of RGB image
label: integer, identifier for the ground truth for the network
synset: string, unique WordNet ID specifying the label, e.g., 'n02323233'
human: string, human-readable label, e.g., 'red fox, Vulpes vulpes'
bbox: list of bounding boxes; each box is a list of floats
specifying [xmin, ymin, xmax, ymax]. All boxes are assumed to belong to
the same label as the image label.
height: integer, image height in pixels
width: integer, image width in pixels
Returns:
Example proto
"""
xmin = []
ymin = []
xmax = []
ymax = []
for b in bbox:
assert len(b) == 4
# pylint: disable=expression-not-assigned
[l.append(point) for l, point in zip([xmin, ymin, xmax, ymax], b)]
# pylint: enable=expression-not-assigned
colorspace = 'RGB'
channels = 3
image_format = 'JPEG'
example = tf.train.Example(features=tf.train.Features(feature={
'image/height': _int64_feature(height),
'image/width': _int64_feature(width),
'image/colorspace': _bytes_feature(colorspace),
'image/channels': _int64_feature(channels),
'image/class/label': _int64_feature(label),
'image/class/synset': _bytes_feature(synset),
'image/class/text': _bytes_feature(human),
'image/object/bbox/xmin': _float_feature(xmin),
'image/object/bbox/xmax': _float_feature(xmax),
'image/object/bbox/ymin': _float_feature(ymin),
'image/object/bbox/ymax': _float_feature(ymax),
'image/object/bbox/label': _int64_feature([label] * len(xmin)),
'image/format': _bytes_feature(image_format),
'image/filename': _bytes_feature(os.path.basename(filename)),
'image/encoded': _bytes_feature(image_buffer)}))
return example
class ImageCoder(object):
"""Helper class that provides TensorFlow image coding utilities."""
def __init__(self):
# Create a single Session to run all image coding calls.
self._sess = tf.Session()
# Initializes function that converts PNG to JPEG data.
self._png_data = tf.placeholder(dtype=tf.string)
image = tf.image.decode_png(self._png_data, channels=3)
self._png_to_jpeg = tf.image.encode_jpeg(image, format='rgb', quality=100)
# Initializes function that converts CMYK JPEG data to RGB JPEG data.
self._cmyk_data = tf.placeholder(dtype=tf.string)
image = tf.image.decode_jpeg(self._cmyk_data, channels=0)
self._cmyk_to_rgb = tf.image.encode_jpeg(image, format='rgb', quality=100)
# Initializes function that decodes RGB JPEG data.
self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)
def png_to_jpeg(self, image_data):
return self._sess.run(self._png_to_jpeg,
feed_dict={self._png_data: image_data})
def cmyk_to_rgb(self, image_data):
return self._sess.run(self._cmyk_to_rgb,
feed_dict={self._cmyk_data: image_data})
def decode_jpeg(self, image_data):
image = self._sess.run(self._decode_jpeg,
feed_dict={self._decode_jpeg_data: image_data})
assert len(image.shape) == 3
assert image.shape[2] == 3
return image
def _is_png(filename):
"""Determine if a file contains a PNG format image.
Args:
filename: string, path of the image file.
Returns:
boolean indicating if the image is a PNG.
"""
# File list from:
# https://groups.google.com/forum/embed/?place=forum/torch7#!topic/torch7/fOSTXHIESSU
return 'n02105855_2933.JPEG' in filename
def _is_cmyk(filename):
"""Determine if file contains a CMYK JPEG format image.
Args:
filename: string, path of the image file.
Returns:
boolean indicating if the image is a JPEG encoded with CMYK color space.
"""
# File list from:
# https://github.com/cytsai/ilsvrc-cmyk-image-list
blacklist = ['n01739381_1309.JPEG', 'n02077923_14822.JPEG',
'n02447366_23489.JPEG', 'n02492035_15739.JPEG',
'n02747177_10752.JPEG', 'n03018349_4028.JPEG',
'n03062245_4620.JPEG', 'n03347037_9675.JPEG',
'n03467068_12171.JPEG', 'n03529860_11437.JPEG',
'n03544143_17228.JPEG', 'n03633091_5218.JPEG',
'n03710637_5125.JPEG', 'n03961711_5286.JPEG',
'n04033995_2932.JPEG', 'n04258138_17003.JPEG',
'n04264628_27969.JPEG', 'n04336792_7448.JPEG',
'n04371774_5854.JPEG', 'n04596742_4225.JPEG',
'n07583066_647.JPEG', 'n13037406_4650.JPEG']
return filename.split('/')[-1] in blacklist
def _process_image(filename, coder):
"""Process a single image file.
Args:
filename: string, path to an image file e.g., '/path/to/example.JPG'.
coder: instance of ImageCoder to provide TensorFlow image coding utils.
Returns:
image_buffer: string, JPEG encoding of RGB image.
height: integer, image height in pixels.
width: integer, image width in pixels.
"""
# Read the image file.
image_data = tf.gfile.FastGFile(filename, 'rb').read()
# Clean the dirty data.
if _is_png(filename):
# 1 image is a PNG.
print('Converting PNG to JPEG for %s' % filename)
image_data = coder.png_to_jpeg(image_data)
elif _is_cmyk(filename):
# 22 JPEG images are in CMYK colorspace.
print('Converting CMYK to RGB for %s' % filename)
image_data = coder.cmyk_to_rgb(image_data)
# Decode the RGB JPEG.
image = coder.decode_jpeg(image_data)
# Check that image converted to RGB
assert len(image.shape) == 3
height = image.shape[0]
width = image.shape[1]
assert image.shape[2] == 3
return image_data, height, width
def _process_image_files_batch(coder, thread_index, ranges, name, filenames,
synsets, labels, humans, bboxes, num_shards):
"""Processes and saves list of images as TFRecord in 1 thread.
Args:
coder: instance of ImageCoder to provide TensorFlow image coding utils.
thread_index: integer, unique batch to run; index is within [0, len(ranges)).
ranges: list of pairs of integers specifying the range of each batch to
analyze in parallel.
name: string, unique identifier specifying the data set
filenames: list of strings; each string is a path to an image file
synsets: list of strings; each string is a unique WordNet ID
labels: list of integer; each integer identifies the ground truth
humans: list of strings; each string is a human-readable label
bboxes: list of bounding boxes for each image. Note that each entry in this
list may contain 0 or more entries corresponding to the number of bounding
box annotations for the image.
num_shards: integer number of shards for this data set.
"""
# Each thread produces N shards where N = int(num_shards / num_threads).
# For instance, if num_shards = 128 and num_threads = 2, then the first
# thread would produce shards [0, 64).
num_threads = len(ranges)
assert not num_shards % num_threads
num_shards_per_batch = int(num_shards / num_threads)
shard_ranges = np.linspace(ranges[thread_index][0],
ranges[thread_index][1],
num_shards_per_batch + 1).astype(int)
num_files_in_thread = ranges[thread_index][1] - ranges[thread_index][0]
counter = 0
for s in xrange(num_shards_per_batch):
# Generate a sharded version of the file name, e.g. 'train-00002-of-00010'
shard = thread_index * num_shards_per_batch + s
output_filename = '%s-%.5d-of-%.5d' % (name, shard, num_shards)
output_file = os.path.join(FLAGS.output_directory, output_filename)
writer = tf.python_io.TFRecordWriter(output_file)
shard_counter = 0
files_in_shard = np.arange(shard_ranges[s], shard_ranges[s + 1], dtype=int)
for i in files_in_shard:
filename = filenames[i]
label = labels[i]
synset = synsets[i]
human = humans[i]
bbox = bboxes[i]
image_buffer, height, width = _process_image(filename, coder)
example = _convert_to_example(filename, image_buffer, label,
synset, human, bbox,
height, width)
writer.write(example.SerializeToString())
shard_counter += 1
counter += 1
if not counter % 1000:
print('%s [thread %d]: Processed %d of %d images in thread batch.' %
(datetime.now(), thread_index, counter, num_files_in_thread))
sys.stdout.flush()
print('%s [thread %d]: Wrote %d images to %s' %
(datetime.now(), thread_index, shard_counter, output_file))
sys.stdout.flush()
shard_counter = 0
print('%s [thread %d]: Wrote %d images to %d shards.' %
(datetime.now(), thread_index, counter, num_files_in_thread))
sys.stdout.flush()
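# Illustrative check of the shard naming above (values assumed): with
# num_shards = 128 and num_threads = 16, each thread writes 8 shards, and
# thread 0 produces:
#
#   num_shards_per_batch = 128 // 16  # -> 8
#   ['train-%.5d-of-%.5d' % (s, 128) for s in range(num_shards_per_batch)]
#   # -> ['train-00000-of-00128', ..., 'train-00007-of-00128']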
def _process_image_files(name, filenames, synsets, labels, humans,
bboxes, num_shards):
"""Process and save list of images as TFRecord of Example protos.
Args:
name: string, unique identifier specifying the data set
filenames: list of strings; each string is a path to an image file
synsets: list of strings; each string is a unique WordNet ID
labels: list of integer; each integer identifies the ground truth
humans: list of strings; each string is a human-readable label
bboxes: list of bounding boxes for each image. Note that each entry in this
list may contain 0 or more entries corresponding to the number of bounding
box annotations for the image.
num_shards: integer number of shards for this data set.
"""
assert len(filenames) == len(synsets)
assert len(filenames) == len(labels)
assert len(filenames) == len(humans)
assert len(filenames) == len(bboxes)
# Break all images into batches, with each batch covering [ranges[i][0], ranges[i][1]).
spacing = np.linspace(0, len(filenames), FLAGS.num_threads + 1).astype(np.int)
ranges = []
threads = []
for i in xrange(len(spacing) - 1):
ranges.append([spacing[i], spacing[i+1]])
# Launch a thread for each batch.
print('Launching %d threads for spacings: %s' % (FLAGS.num_threads, ranges))
sys.stdout.flush()
# Create a mechanism for monitoring when all threads are finished.
coord = tf.train.Coordinator()
# Create a generic TensorFlow-based utility for converting all image codings.
coder = ImageCoder()
threads = []
for thread_index in xrange(len(ranges)):
args = (coder, thread_index, ranges, name, filenames,
synsets, labels, humans, bboxes, num_shards)
t = threading.Thread(target=_process_image_files_batch, args=args)
t.start()
threads.append(t)
# Wait for all the threads to terminate.
coord.join(threads)
print('%s: Finished writing all %d images in data set.' %
(datetime.now(), len(filenames)))
sys.stdout.flush()
def _find_image_files(data_dir, labels_file):
"""Build a list of all images files and labels in the data set.
Args:
data_dir: string, path to the root directory of images.
Assumes that the ImageNet data set resides in JPEG files located in
the following directory structure.
data_dir/n01440764/ILSVRC2012_val_00000293.JPEG
data_dir/n01440764/ILSVRC2012_val_00000543.JPEG
where 'n01440764' is the unique synset label associated with these images.
labels_file: string, path to the labels file.
The list of valid labels are held in this file. Assumes that the file
contains entries as such:
n01440764
n01443537
n01484850
where each line corresponds to a label expressed as a synset. We map
each synset contained in the file to an integer (based on the alphabetical
ordering) starting with the integer 1 corresponding to the synset
contained in the first line.
The reason we start the integer labels at 1 is to reserve label 0 as an
unused background class.
Returns:
filenames: list of strings; each string is a path to an image file.
synsets: list of strings; each string is a unique WordNet ID.
labels: list of integer; each integer identifies the ground truth.
"""
print('Determining list of input files and labels from %s.' % data_dir)
challenge_synsets = [l.strip() for l in
tf.gfile.FastGFile(labels_file, 'r').readlines()]
labels = []
filenames = []
synsets = []
# Leave label index 0 empty as a background class.
label_index = 1
# Construct the list of JPEG files and labels.
for synset in challenge_synsets:
jpeg_file_path = '%s/%s/*.JPEG' % (data_dir, synset)
matching_files = tf.gfile.Glob(jpeg_file_path)
labels.extend([label_index] * len(matching_files))
synsets.extend([synset] * len(matching_files))
filenames.extend(matching_files)
if not label_index % 100:
print('Finished finding files in %d of %d classes.' % (
label_index, len(challenge_synsets)))
label_index += 1
# Shuffle the ordering of all image files in order to guarantee
# random ordering of the images with respect to label in the
# saved TFRecord files. Make the randomization repeatable.
shuffled_index = range(len(filenames))
random.seed(12345)
random.shuffle(shuffled_index)
filenames = [filenames[i] for i in shuffled_index]
synsets = [synsets[i] for i in shuffled_index]
labels = [labels[i] for i in shuffled_index]
print('Found %d JPEG files across %d labels inside %s.' %
(len(filenames), len(challenge_synsets), data_dir))
return filenames, synsets, labels
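# Example of the label assignment above, using the synsets listed in the
# docstring (illustrative):
#
#   challenge_synsets = ['n01440764', 'n01443537', 'n01484850']
#   # files under data_dir/n01440764/*.JPEG -> label 1
#   # files under data_dir/n01443537/*.JPEG -> label 2
#   # files under data_dir/n01484850/*.JPEG -> label 3
#   # label 0 is reserved as the unused background class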
def _find_human_readable_labels(synsets, synset_to_human):
"""Build a list of human-readable labels.
Args:
synsets: list of strings; each string is a unique WordNet ID.
synset_to_human: dict of synset to human labels, e.g.,
'n02119022' --> 'red fox, Vulpes vulpes'
Returns:
List of human-readable strings corresponding to each synset.
"""
humans = []
for s in synsets:
assert s in synset_to_human, ('Failed to find: %s' % s)
humans.append(synset_to_human[s])
return humans
def _find_image_bounding_boxes(filenames, image_to_bboxes):
"""Find the bounding boxes for a given image file.
Args:
filenames: list of strings; each string is a path to an image file.
image_to_bboxes: dictionary mapping image file names to a list of
bounding boxes. This list contains 0+ bounding boxes.
Returns:
List of bounding boxes for each image. Note that each entry in this
list may contain 0 or more entries corresponding to the number of bounding
box annotations for the image.
"""
num_image_bbox = 0
bboxes = []
for f in filenames:
basename = os.path.basename(f)
if basename in image_to_bboxes:
bboxes.append(image_to_bboxes[basename])
num_image_bbox += 1
else:
bboxes.append([])
print('Found %d images with bboxes out of %d images' % (
num_image_bbox, len(filenames)))
return bboxes
def _process_dataset(name, directory, num_shards, synset_to_human,
image_to_bboxes):
"""Process a complete data set and save it as a TFRecord.
Args:
name: string, unique identifier specifying the data set.
directory: string, root path to the data set.
num_shards: integer number of shards for this data set.
synset_to_human: dict of synset to human labels, e.g.,
'n02119022' --> 'red fox, Vulpes vulpes'
image_to_bboxes: dictionary mapping image file names to a list of
bounding boxes. This list contains 0+ bounding boxes.
"""
filenames, synsets, labels = _find_image_files(directory, FLAGS.labels_file)
humans = _find_human_readable_labels(synsets, synset_to_human)
bboxes = _find_image_bounding_boxes(filenames, image_to_bboxes)
_process_image_files(name, filenames, synsets, labels,
humans, bboxes, num_shards)
def _build_synset_lookup(imagenet_metadata_file):
"""Build lookup for synset to human-readable label.
Args:
imagenet_metadata_file: string, path to file containing mapping from
synset to human-readable label.
Assumes each line of the file looks like:
n02119247 black fox
n02119359 silver fox
n02119477 red fox, Vulpes fulva
where each line corresponds to a unique mapping. Note that each line is
formatted as <synset>\t<human readable label>.
Returns:
Dictionary of synset to human labels, such as:
'n02119022' --> 'red fox, Vulpes vulpes'
"""
lines = tf.gfile.FastGFile(imagenet_metadata_file, 'r').readlines()
synset_to_human = {}
for l in lines:
if l:
parts = l.strip().split('\t')
assert len(parts) == 2
synset = parts[0]
human = parts[1]
synset_to_human[synset] = human
return synset_to_human
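# Illustrative usage, assuming a metadata file formatted as in the docstring:
#
#   synset_to_human = _build_synset_lookup('imagenet_metadata.txt')
#   synset_to_human['n02119477']  # -> 'red fox, Vulpes fulva'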
def _build_bounding_box_lookup(bounding_box_file):
"""Build a lookup from image file to bounding boxes.
Args:
bounding_box_file: string, path to file with bounding boxes annotations.
Assumes each line of the file looks like:
n00007846_64193.JPEG,0.0060,0.2620,0.7545,0.9940
where each line corresponds to one bounding box annotation associated
with an image. Each line can be parsed as:
<JPEG file name>, <xmin>, <ymin>, <xmax>, <ymax>
Note that there might exist multiple bounding box annotations associated
with an image file. This file is the output of process_bounding_boxes.py.
Returns:
Dictionary mapping image file names to a list of bounding boxes. This list
contains 0+ bounding boxes.
"""
lines = tf.gfile.FastGFile(bounding_box_file, 'r').readlines()
images_to_bboxes = {}
num_bbox = 0
num_image = 0
for l in lines:
if l:
parts = l.split(',')
assert len(parts) == 5, ('Failed to parse: %s' % l)
filename = parts[0]
xmin = float(parts[1])
ymin = float(parts[2])
xmax = float(parts[3])
ymax = float(parts[4])
box = [xmin, ymin, xmax, ymax]
if filename not in images_to_bboxes:
images_to_bboxes[filename] = []
num_image += 1
images_to_bboxes[filename].append(box)
num_bbox += 1
print('Successfully read %d bounding boxes '
'across %d images.' % (num_bbox, num_image))
return images_to_bboxes
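# Illustrative usage, assuming a CSV formatted as in the docstring:
#
#   images_to_bboxes = _build_bounding_box_lookup('bounding_boxes.csv')
#   images_to_bboxes['n00007846_64193.JPEG']
#   # -> [[0.0060, 0.2620, 0.7545, 0.9940]]  # [xmin, ymin, xmax, ymax]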
def main(unused_argv):
assert not FLAGS.train_shards % FLAGS.num_threads, (
'Please make the FLAGS.num_threads commensurate with FLAGS.train_shards')
assert not FLAGS.validation_shards % FLAGS.num_threads, (
'Please make the FLAGS.num_threads commensurate with '
'FLAGS.validation_shards')
print('Saving results to %s' % FLAGS.output_directory)
# Build a map from synset to human-readable label.
synset_to_human = _build_synset_lookup(FLAGS.imagenet_metadata_file)
image_to_bboxes = _build_bounding_box_lookup(FLAGS.bounding_box_file)
# Run it!
_process_dataset('validation', FLAGS.validation_directory,
FLAGS.validation_shards, synset_to_human, image_to_bboxes)
_process_dataset('train', FLAGS.train_directory, FLAGS.train_shards,
synset_to_human, image_to_bboxes)
if __name__ == '__main__':
tf.app.run()
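# Example invocation (paths are illustrative; the flag names match those
# passed by download_and_preprocess_imagenet.sh below):
#
#   python build_imagenet_data.py \
#     --train_directory=/tmp/raw-data/train \
#     --validation_directory=/tmp/raw-data/validation \
#     --output_directory=/tmp/imagenet-tfrecords \
#     --imagenet_metadata_file=imagenet_metadata.txt \
#     --labels_file=imagenet_lsvrc_2015_synsets.txt \
#     --bounding_box_file=imagenet_2012_bounding_boxes.csv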
#!/bin/bash
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Script to download and preprocess the flowers data set. This data set
# provides a demonstration for how to perform fine-tuning (i.e. transfer
# learning) from one model to a new data set.
#
# This script provides a demonstration for how to prepare an arbitrary
# data set for training an Inception v3 model.
#
# We demonstrate this with the flowers data set, which consists of labeled
# images of flowers from 5 classes:
#
# daisy, dandelion, roses, sunflowers, tulips
#
# The final output of this script is a set of sharded TFRecord files
# containing serialized Example protocol buffers. See build_image_data.py for
# details of how the Example protocol buffer contains image data.
#
# usage:
# ./download_and_preprocess_flowers.sh [data-dir]
set -e
if [ -z "$1" ]; then
echo "usage download_and_preprocess_flowers.sh [data dir]"
exit
fi
# Create the output and temporary directories.
DATA_DIR="${1%/}"
SCRATCH_DIR="${DATA_DIR}/raw-data/"
mkdir -p "${DATA_DIR}"
mkdir -p "${SCRATCH_DIR}"
WORK_DIR="$0.runfiles/inception"
# Download the flowers data.
DATA_URL="http://download.tensorflow.org/example_images/flower_photos.tgz"
CURRENT_DIR=$(pwd)
cd "${DATA_DIR}"
TARBALL="flower_photos.tgz"
if [ ! -f ${TARBALL} ]; then
echo "Downloading flower data set."
wget -O ${TARBALL} "${DATA_URL}"
else
echo "Skipping download of flower data."
fi
# Note the locations of the train and validation data.
TRAIN_DIRECTORY="${SCRATCH_DIR}train/"
VALIDATION_DIRECTORY="${SCRATCH_DIR}validation/"
# Expand the data into the flower_photos/ directory and rename it as the
# train directory.
tar xf flower_photos.tgz
rm -rf "${TRAIN_DIRECTORY}" "${VALIDATION_DIRECTORY}"
mv flower_photos "${TRAIN_DIRECTORY}"
# Generate a list of 5 labels: daisy, dandelion, roses, sunflowers, tulips
LABELS_FILE="${SCRATCH_DIR}/labels.txt"
ls -1 "${TRAIN_DIRECTORY}" | grep -v 'LICENSE' | sed 's/\///' | sort > "${LABELS_FILE}"
# Generate the validation data set.
while read LABEL; do
VALIDATION_DIR_FOR_LABEL="${VALIDATION_DIRECTORY}${LABEL}"
TRAIN_DIR_FOR_LABEL="${TRAIN_DIRECTORY}${LABEL}"
# Move 100 randomly selected images to the validation set.
mkdir -p "${VALIDATION_DIR_FOR_LABEL}"
VALIDATION_IMAGES=$(ls -1 "${TRAIN_DIR_FOR_LABEL}" | shuf | head -100)
for IMAGE in ${VALIDATION_IMAGES}; do
mv -f "${TRAIN_DIRECTORY}${LABEL}/${IMAGE}" "${VALIDATION_DIR_FOR_LABEL}"
done
done < "${LABELS_FILE}"
# Build the TFRecords version of the image data.
cd "${CURRENT_DIR}"
BUILD_SCRIPT="${WORK_DIR}/build_image_data"
OUTPUT_DIRECTORY="${DATA_DIR}"
"${BUILD_SCRIPT}" \
--train_directory="${TRAIN_DIRECTORY}" \
--validation_directory="${VALIDATION_DIRECTORY}" \
--output_directory="${OUTPUT_DIRECTORY}" \
--labels_file="${LABELS_FILE}"
#!/bin/bash
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Script to download and preprocess ImageNet Challenge 2012
# training and validation data set.
#
# The final output of this script is a set of sharded TFRecord files
# containing serialized Example protocol buffers. See build_imagenet_data.py for
# details of how the Example protocol buffers contain the ImageNet data.
#
# The final output of this script appears as such:
#
# data_dir/train-00000-of-01024
# data_dir/train-00001-of-01024
# ...
# data_dir/train-00127-of-01024
#
# and
#
# data_dir/validation-00000-of-00128
# data_dir/validation-00001-of-00128
# ...
# data_dir/validation-00127-of-00128
#
# Note that this script may take several hours to run to completion. The
# conversion of the ImageNet data to TFRecords alone takes 2-3 hours depending
# on the speed of your machine. Please be patient.
#
# **IMPORTANT**
# To download the raw images, the user must create an account with
# image-net.org and generate a username and access_key; both are required
# for downloading the raw images.
#
# usage:
# ./download_and_preprocess_imagenet.sh [data-dir]
set -e
if [ -z "$1" ]; then
echo "usage download_and_preprocess_imagenet.sh [data dir]"
exit
fi
# Create the output and temporary directories.
DATA_DIR="${1%/}"
SCRATCH_DIR="${DATA_DIR}/raw-data/"
mkdir -p "${DATA_DIR}"
mkdir -p "${SCRATCH_DIR}"
WORK_DIR="$0.runfiles/inception"
# Download the ImageNet data.
LABELS_FILE="${WORK_DIR}/data/imagenet_lsvrc_2015_synsets.txt"
DOWNLOAD_SCRIPT="${WORK_DIR}/data/download_imagenet.sh"
"${DOWNLOAD_SCRIPT}" "${SCRATCH_DIR}" "${LABELS_FILE}"
# Note the locations of the train and validation data.
TRAIN_DIRECTORY="${SCRATCH_DIR}train/"
VALIDATION_DIRECTORY="${SCRATCH_DIR}validation/"
# Preprocess the validation data by moving the images into the appropriate
# sub-directory based on the label (synset) of the image.
echo "Organizing the validation data into sub-directories."
PREPROCESS_VAL_SCRIPT="${WORK_DIR}/data/preprocess_imagenet_validation_data.py"
VAL_LABELS_FILE="${WORK_DIR}/data/imagenet_2012_validation_synset_labels.txt"
"${PREPROCESS_VAL_SCRIPT}" "${VALIDATION_DIRECTORY}" "${VAL_LABELS_FILE}"
# Convert the XML files for bounding box annotations into a single CSV.
echo "Extracting bounding box information from XML."
BOUNDING_BOX_SCRIPT="${WORK_DIR}/data/process_bounding_boxes.py"
BOUNDING_BOX_FILE="${SCRATCH_DIR}/imagenet_2012_bounding_boxes.csv"
BOUNDING_BOX_DIR="${SCRATCH_DIR}bounding_boxes/"
"${BOUNDING_BOX_SCRIPT}" "${BOUNDING_BOX_DIR}" "${LABELS_FILE}" \
| sort >"${BOUNDING_BOX_FILE}"
echo "Finished downloading and preprocessing the ImageNet data."
# Build the TFRecords version of the ImageNet data.
BUILD_SCRIPT="${WORK_DIR}/build_imagenet_data"
OUTPUT_DIRECTORY="${DATA_DIR}"
IMAGENET_METADATA_FILE="${WORK_DIR}/data/imagenet_metadata.txt"
"${BUILD_SCRIPT}" \
--train_directory="${TRAIN_DIRECTORY}" \
--validation_directory="${VALIDATION_DIRECTORY}" \
--output_directory="${OUTPUT_DIRECTORY}" \
--imagenet_metadata_file="${IMAGENET_METADATA_FILE}" \
--labels_file="${LABELS_FILE}" \
--bounding_box_file="${BOUNDING_BOX_FILE}"
#!/bin/bash
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Script to download ImageNet Challenge 2012 training and validation data set.
#
# Downloads and decompresses raw images and bounding boxes.
#
# **IMPORTANT**
# To download the raw images, the user must create an account with
# image-net.org and generate a username and access_key; both are required
# for downloading the raw images.
#
# usage:
# ./download_imagenet.sh [dirname]
set -e
if [ "x$IMAGENET_ACCESS_KEY" == x -o "x$IMAGENET_USERNAME" == x ]; then
cat <<END
In order to download the ImageNet data, you have to create an account with
image-net.org. This will get you a username and an access key. You can set the
IMAGENET_USERNAME and IMAGENET_ACCESS_KEY environment variables, or you can
enter the credentials here.
END
read -p "Username: " IMAGENET_USERNAME
read -p "Access key: " IMAGENET_ACCESS_KEY
fi
OUTDIR="${1:-./imagenet-data}"
SYNSETS_FILE="${2:-./synsets.txt}"
SYNSETS_FILE="${PWD}/${SYNSETS_FILE}"
echo "Saving downloaded files to $OUTDIR"
mkdir -p "${OUTDIR}"
CURRENT_DIR=$(pwd)
BBOX_DIR="${OUTDIR}bounding_boxes"
mkdir -p "${BBOX_DIR}"
cd "${OUTDIR}"
# Download and process all of the ImageNet bounding boxes.
BASE_URL="http://www.image-net.org/challenges/LSVRC/2012/nonpub"
# See here for details: http://www.image-net.org/download-bboxes
BOUNDING_BOX_ANNOTATIONS="${BASE_URL}/ILSVRC2012_bbox_train_v2.tar.gz"
BBOX_TAR_BALL="${BBOX_DIR}/annotations.tar.gz"
echo "Downloading bounding box annotations."
wget "${BOUNDING_BOX_ANNOTATIONS}" -O "${BBOX_TAR_BALL}"
echo "Uncompressing bounding box annotations ..."
tar xzf "${BBOX_TAR_BALL}" -C "${BBOX_DIR}"
LABELS_ANNOTATED="${BBOX_DIR}/*"
NUM_XML=$(ls -1 ${LABELS_ANNOTATED} | wc -l)
echo "Identified ${NUM_XML} bounding box annotations."
# Download and uncompress all images from the ImageNet 2012 validation dataset.
VALIDATION_TARBALL="ILSVRC2012_img_val.tar"
OUTPUT_PATH="${OUTDIR}validation/"
mkdir -p "${OUTPUT_PATH}"
cd "${OUTDIR}/.."
echo "Downloading ${VALIDATION_TARBALL} to ${OUTPUT_PATH}."
wget -nd -c "${BASE_URL}/${VALIDATION_TARBALL}"
tar xf "${VALIDATION_TARBALL}" -C "${OUTPUT_PATH}"
# Download all images from the ImageNet 2012 train dataset.
TRAIN_TARBALL="ILSVRC2012_img_train.tar"
OUTPUT_PATH="${OUTDIR}train/"
mkdir -p "${OUTPUT_PATH}"
cd "${OUTDIR}/.."
echo "Downloading ${TRAIN_TARBALL} to ${OUTPUT_PATH}."
wget -nd -c "${BASE_URL}/${TRAIN_TARBALL}"
# Un-compress the individual tar-files within the train tar-file.
echo "Uncompressing individual train tar-balls in the training data."
while read SYNSET; do
echo "Processing: ${SYNSET}"
# Create a directory and delete anything there.
mkdir -p "${OUTPUT_PATH}/${SYNSET}"
rm -rf "${OUTPUT_PATH}/${SYNSET}/*"
# Uncompress into the directory.
tar xf "${TRAIN_TARBALL}" "${SYNSET}.tar"
tar xf "${SYNSET}.tar" -C "${OUTPUT_PATH}/${SYNSET}/"
rm -f "${SYNSET}.tar"
echo "Finished processing: ${SYNSET}"
done < "${SYNSETS_FILE}"
n01440764
n01443537
n01484850
n01491361
n01494475
n01496331
n01498041
n01514668
n01514859
n01518878
n01530575
n01531178
n01532829
n01534433
n01537544
n01558993
n01560419
n01580077
n01582220
n01592084
n01601694
n01608432
n01614925
n01616318
n01622779
n01629819
n01630670
n01631663
n01632458
n01632777
n01641577
n01644373
n01644900
n01664065
n01665541
n01667114
n01667778
n01669191
n01675722
n01677366
n01682714
n01685808
n01687978
n01688243
n01689811
n01692333
n01693334
n01694178
n01695060
n01697457
n01698640
n01704323
n01728572
n01728920
n01729322
n01729977
n01734418
n01735189
n01737021
n01739381
n01740131
n01742172
n01744401
n01748264
n01749939
n01751748
n01753488
n01755581
n01756291
n01768244
n01770081
n01770393
n01773157
n01773549
n01773797
n01774384
n01774750
n01775062
n01776313
n01784675
n01795545
n01796340
n01797886
n01798484
n01806143
n01806567
n01807496
n01817953
n01818515
n01819313
n01820546
n01824575
n01828970
n01829413
n01833805
n01843065
n01843383
n01847000
n01855032
n01855672
n01860187
n01871265
n01872401
n01873310
n01877812
n01882714
n01883070
n01910747
n01914609
n01917289
n01924916
n01930112
n01943899
n01944390
n01945685
n01950731
n01955084
n01968897
n01978287
n01978455
n01980166
n01981276
n01983481
n01984695
n01985128
n01986214
n01990800
n02002556
n02002724
n02006656
n02007558
n02009229
n02009912
n02011460
n02012849
n02013706
n02017213
n02018207
n02018795
n02025239
n02027492
n02028035
n02033041
n02037110
n02051845
n02056570
n02058221
n02066245
n02071294
n02074367
n02077923
n02085620
n02085782
n02085936
n02086079
n02086240
n02086646
n02086910
n02087046
n02087394
n02088094
n02088238
n02088364
n02088466
n02088632
n02089078
n02089867
n02089973
n02090379
n02090622
n02090721
n02091032
n02091134
n02091244
n02091467
n02091635
n02091831
n02092002
n02092339
n02093256
n02093428
n02093647
n02093754
n02093859
n02093991
n02094114
n02094258
n02094433
n02095314
n02095570
n02095889
n02096051
n02096177
n02096294
n02096437
n02096585
n02097047
n02097130
n02097209
n02097298
n02097474
n02097658
n02098105
n02098286
n02098413
n02099267
n02099429
n02099601
n02099712
n02099849
n02100236
n02100583
n02100735
n02100877
n02101006
n02101388
n02101556
n02102040
n02102177
n02102318
n02102480
n02102973
n02104029
n02104365
n02105056
n02105162
n02105251
n02105412
n02105505
n02105641
n02105855
n02106030
n02106166
n02106382
n02106550
n02106662
n02107142
n02107312
n02107574
n02107683
n02107908
n02108000
n02108089
n02108422
n02108551
n02108915
n02109047
n02109525
n02109961
n02110063
n02110185
n02110341
n02110627
n02110806
n02110958
n02111129
n02111277
n02111500
n02111889
n02112018
n02112137
n02112350
n02112706
n02113023
n02113186
n02113624
n02113712
n02113799
n02113978
n02114367
n02114548
n02114712
n02114855
n02115641
n02115913
n02116738
n02117135
n02119022
n02119789
n02120079
n02120505
n02123045
n02123159
n02123394
n02123597
n02124075
n02125311
n02127052
n02128385
n02128757
n02128925
n02129165
n02129604
n02130308
n02132136
n02133161
n02134084
n02134418
n02137549
n02138441
n02165105
n02165456
n02167151
n02168699
n02169497
n02172182
n02174001
n02177972
n02190166
n02206856
n02219486
n02226429
n02229544
n02231487
n02233338
n02236044
n02256656
n02259212
n02264363
n02268443
n02268853
n02276258
n02277742
n02279972
n02280649
n02281406
n02281787
n02317335
n02319095
n02321529
n02325366
n02326432
n02328150
n02342885
n02346627
n02356798
n02361337
n02363005
n02364673
n02389026
n02391049
n02395406
n02396427
n02397096
n02398521
n02403003
n02408429
n02410509
n02412080
n02415577
n02417914
n02422106
n02422699
n02423022
n02437312
n02437616
n02441942
n02442845
n02443114
n02443484
n02444819
n02445715
n02447366
n02454379
n02457408
n02480495
n02480855
n02481823
n02483362
n02483708
n02484975
n02486261
n02486410
n02487347
n02488291
n02488702
n02489166
n02490219
n02492035
n02492660
n02493509
n02493793
n02494079
n02497673
n02500267
n02504013
n02504458
n02509815
n02510455
n02514041
n02526121
n02536864
n02606052
n02607072
n02640242
n02641379
n02643566
n02655020
n02666196
n02667093
n02669723
n02672831
n02676566
n02687172
n02690373
n02692877
n02699494
n02701002
n02704792
n02708093
n02727426
n02730930
n02747177
n02749479
n02769748
n02776631
n02777292
n02782093
n02783161
n02786058
n02787622
n02788148
n02790996
n02791124
n02791270
n02793495
n02794156
n02795169
n02797295
n02799071
n02802426
n02804414
n02804610
n02807133
n02808304
n02808440
n02814533
n02814860
n02815834
n02817516
n02823428
n02823750
n02825657
n02834397
n02835271
n02837789
n02840245
n02841315
n02843684
n02859443
n02860847
n02865351
n02869837
n02870880
n02871525
n02877765
n02879718
n02883205
n02892201
n02892767
n02894605
n02895154
n02906734
n02909870
n02910353
n02916936
n02917067
n02927161
n02930766
n02939185
n02948072
n02950826
n02951358
n02951585
n02963159
n02965783
n02966193
n02966687
n02971356
n02974003
n02977058
n02978881
n02979186
n02980441
n02981792
n02988304
n02992211
n02992529
n02999410
n03000134
n03000247
n03000684
n03014705
n03016953
n03017168
n03018349
n03026506
n03028079
n03032252
n03041632
n03042490
n03045698
n03047690
n03062245
n03063599
n03063689
n03065424
n03075370
n03085013
n03089624
n03095699
n03100240
n03109150
n03110669
n03124043
n03124170
n03125729
n03126707
n03127747
n03127925
n03131574
n03133878
n03134739
n03141823
n03146219
n03160309
n03179701
n03180011
n03187595
n03188531
n03196217
n03197337
n03201208
n03207743
n03207941
n03208938
n03216828
n03218198
n03220513
n03223299
n03240683
n03249569
n03250847
n03255030
n03259280
n03271574
n03272010
n03272562
n03290653
n03291819
n03297495
n03314780
n03325584
n03337140
n03344393
n03345487
n03347037
n03355925
n03372029
n03376595
n03379051
n03384352
n03388043
n03388183
n03388549
n03393912
n03394916
n03400231
n03404251
n03417042
n03424325
n03425413
n03443371
n03444034
n03445777
n03445924
n03447447
n03447721
n03450230
n03452741
n03457902
n03459775
n03461385
n03467068
n03476684
n03476991
n03478589
n03481172
n03482405
n03483316
n03485407
n03485794
n03492542
n03494278
n03495258
n03496892
n03498962
n03527444
n03529860
n03530642
n03532672
n03534580
n03535780
n03538406
n03544143
n03584254
n03584829
n03590841
n03594734
n03594945
n03595614
n03598930
n03599486
n03602883
n03617480
n03623198
n03627232
n03630383
n03633091
n03637318
n03642806
n03649909
n03657121
n03658185
n03661043
n03662601
n03666591
n03670208
n03673027
n03676483
n03680355
n03690938
n03691459
n03692522
n03697007
n03706229
n03709823
n03710193
n03710637
n03710721
n03717622
n03720891
n03721384
n03724870
n03729826
n03733131
n03733281
n03733805
n03742115
n03743016
n03759954
n03761084
n03763968
n03764736
n03769881
n03770439
n03770679
n03773504
n03775071
n03775546
n03776460
n03777568
n03777754
n03781244
n03782006
n03785016
n03786901
n03787032
n03788195
n03788365
n03791053
n03792782
n03792972
n03793489
n03794056
n03796401
n03803284
n03804744
n03814639
n03814906
n03825788
n03832673
n03837869
n03838899
n03840681
n03841143
n03843555
n03854065
n03857828
n03866082
n03868242
n03868863
n03871628
n03873416
n03874293
n03874599
n03876231
n03877472
n03877845
n03884397
n03887697
n03888257
n03888605
n03891251
n03891332
n03895866
n03899768
n03902125
n03903868
n03908618
n03908714
n03916031
n03920288
n03924679
n03929660
n03929855
n03930313
n03930630
n03933933
n03935335
n03937543
n03938244
n03942813
n03944341
n03947888
n03950228
n03954731
n03956157
n03958227
n03961711
n03967562
n03970156
n03976467
n03976657
n03977966
n03980874
n03982430
n03983396
n03991062
n03992509
n03995372
n03998194
n04004767
n04005630
n04008634
n04009552
n04019541
n04023962
n04026417
n04033901
n04033995
n04037443
n04039381
n04040759
n04041544
n04044716
n04049303
n04065272
n04067472
n04069434
n04070727
n04074963
n04081281
n04086273
n04090263
n04099969
n04111531
n04116512
n04118538
n04118776
n04120489
n04125021
n04127249
n04131690
n04133789
n04136333
n04141076
n04141327
n04141975
n04146614
n04147183
n04149813
n04152593
n04153751
n04154565
n04162706
n04179913
n04192698
n04200800
n04201297
n04204238
n04204347
n04208210
n04209133
n04209239
n04228054
n04229816
n04235860
n04238763
n04239074
n04243546
n04251144
n04252077
n04252225
n04254120
n04254680
n04254777
n04258138
n04259630
n04263257
n04264628
n04265275
n04266014
n04270147
n04273569
n04275548
n04277352
n04285008
n04286575
n04296562
n04310018
n04311004
n04311174
n04317175
n04325704
n04326547
n04328186
n04330267
n04332243
n04335435
n04336792
n04344873
n04346328
n04347754
n04350905
n04355338
n04355933
n04356056
n04357314
n04366367
n04367480
n04370456
n04371430
n04371774
n04372370
n04376876
n04380533
n04389033
n04392985
n04398044
n04399382
n04404412
n04409515
n04417672
n04418357
n04423845
n04428191
n04429376
n04435653
n04442312
n04443257
n04447861
n04456115
n04458633
n04461696
n04462240
n04465501
n04467665
n04476259
n04479046
n04482393
n04483307
n04485082
n04486054
n04487081
n04487394
n04493381
n04501370
n04505470
n04507155
n04509417
n04515003
n04517823
n04522168
n04523525
n04525038
n04525305
n04532106
n04532670
n04536866
n04540053
n04542943
n04548280
n04548362
n04550184
n04552348
n04553703
n04554684
n04557648
n04560804
n04562935
n04579145
n04579432
n04584207
n04589890
n04590129
n04591157
n04591713
n04592741
n04596742
n04597913
n04599235
n04604644
n04606251
n04612504
n04613696
n06359193
n06596364
n06785654
n06794110
n06874185
n07248320
n07565083
n07579787
n07583066
n07584110
n07590611
n07613480
n07614500
n07615774
n07684084
n07693725
n07695742
n07697313
n07697537
n07711569
n07714571
n07714990
n07715103
n07716358
n07716906
n07717410
n07717556
n07718472
n07718747
n07720875
n07730033
n07734744
n07742313
n07745940
n07747607
n07749582
n07753113
n07753275
n07753592
n07754684
n07760859
n07768694
n07802026
n07831146
n07836838
n07860988
n07871810
n07873807
n07875152
n07880968
n07892512
n07920052
n07930864
n07932039
n09193705
n09229709
n09246464
n09256479
n09288635
n09332890
n09399592
n09421951
n09428293
n09468604
n09472597
n09835506
n10148035
n10565667
n11879895
n11939491
n12057211
n12144580
n12267677
n12620546
n12768682
n12985857
n12998815
n13037406
n13040303
n13044778
n13052670
n13054560
n13133613
n15075141
#!/usr/bin/python
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Process the ImageNet Challenge bounding boxes for TensorFlow model training.
Associate the ImageNet 2012 Challenge validation data set with labels.
The raw ImageNet validation data set is expected to reside in JPEG files
located in the following directory structure.
data_dir/ILSVRC2012_val_00000001.JPEG
data_dir/ILSVRC2012_val_00000002.JPEG
...
data_dir/ILSVRC2012_val_00050000.JPEG
This script moves the files into a directory structure as follows:
data_dir/n01440764/ILSVRC2012_val_00000293.JPEG
data_dir/n01440764/ILSVRC2012_val_00000543.JPEG
...
where 'n01440764' is the unique synset label associated with
these images.
This directory reorganization requires a mapping from validation image
number (i.e. suffix of the original file) to the associated label. This
is provided in the ImageNet development kit via a Matlab file.
In order to make life easier and divorce ourselves from Matlab, we instead
supply a custom text file that provides this mapping for us.
Sample usage:
./preprocess_imagenet_validation_data.py ILSVRC2012_img_val \
imagenet_2012_validation_synset_labels.txt
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import os.path
import sys
if __name__ == '__main__':
if len(sys.argv) < 3:
print('Invalid usage\n'
'usage: preprocess_imagenet_validation_data.py '
'<validation data dir> <validation labels file>')
sys.exit(-1)
data_dir = sys.argv[1]
validation_labels_file = sys.argv[2]
# Read in the 50000 synsets associated with the validation data set.
labels = [l.strip() for l in open(validation_labels_file).readlines()]
unique_labels = set(labels)
# Make all sub-directories in the validation data dir.
for label in unique_labels:
labeled_data_dir = os.path.join(data_dir, label)
os.makedirs(labeled_data_dir)
# Move all of the images to the appropriate sub-directory.
for i in xrange(len(labels)):
basename = 'ILSVRC2012_val_000%.5d.JPEG' % (i + 1)
original_filename = os.path.join(data_dir, basename)
if not os.path.exists(original_filename):
print('Failed to find: %s' % original_filename)
sys.exit(-1)
new_filename = os.path.join(data_dir, labels[i], basename)
os.rename(original_filename, new_filename)
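# Illustrative effect of the loop above: if line 293 of the labels file is
# 'n01440764', then
#   data_dir/ILSVRC2012_val_00000293.JPEG
# is moved to
#   data_dir/n01440764/ILSVRC2012_val_00000293.JPEG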
#!/usr/bin/python
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Process the ImageNet Challenge bounding boxes for TensorFlow model training.
This script is called as
process_bounding_boxes.py <dir> [synsets-file]
Where <dir> is a directory containing the downloaded and unpacked bounding box
data. If [synsets-file] is supplied, then only the bounding boxes whose
synsets are contained within this file are returned. Note that the
[synsets-file] file contains synset ids, one per line.
The script dumps out a CSV text file in which each line contains an entry.
n00007846_64193.JPEG,0.0060,0.2620,0.7545,0.9940
The entry can be read as:
<JPEG file name>, <xmin>, <ymin>, <xmax>, <ymax>
The bounding box for <JPEG file name> contains two points (xmin, ymin) and
(xmax, ymax) specifying the lower-left corner and upper-right corner of a
bounding box in *relative* coordinates.
The user supplies a directory where the XML files reside. The directory
structure in the directory <dir> is assumed to look like this:
<dir>/nXXXXXXXX/nXXXXXXXX_YYYY.xml
Each XML file contains a bounding box annotation. The script:
(1) Parses the XML file and extracts the filename, label and bounding box info.
(2) The bounding box is specified in the XML files as integer (xmin, ymin) and
(xmax, ymax) *relative* to image size displayed to the human annotator. The
size of the image displayed to the human annotator is stored in the XML file
as integer (height, width).
Note that the displayed size will differ from the actual size of the image
downloaded from image-net.org. To make the bounding box annotation usable,
we convert the bounding box to floating point numbers relative to the
displayed height and width of the image.
Note that each XML file might contain N bounding box annotations.
Note that the points are all clamped at a range of [0.0, 1.0] because some
human annotations extend outside the range of the supplied image.
See details here: http://image-net.org/download-bboxes
(3) By default, the script outputs all valid bounding boxes. If a
[synsets-file] is supplied, only the subset of bounding boxes associated
with those synsets is output. Importantly, one can supply a list of
synsets in the ImageNet Challenge and output the list of bounding boxes
associated with the training images of the ILSVRC.
We use these bounding boxes to inform the random distortion of images
supplied to the network.
If you run this script successfully, you will see the following output
to stderr:
> Finished processing 544546 XML files.
> Skipped 0 XML files not in ImageNet Challenge.
> Skipped 0 bounding boxes not in ImageNet Challenge.
> Wrote 615299 bounding boxes from 544546 annotated images.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import glob
import os.path
import sys
import xml.etree.ElementTree as ET
class BoundingBox(object):
pass
def GetItem(name, root, index=0):
count = 0
for item in root.iter(name):
if count == index:
return item.text
count += 1
# Failed to find "index" occurrence of item.
return -1
def GetInt(name, root, index=0):
return int(GetItem(name, root, index))
def FindNumberBoundingBoxes(root):
index = 0
while True:
if GetInt('xmin', root, index) == -1:
break
index += 1
return index
def ProcessXMLAnnotation(xml_file):
"""Process a single XML file containing a bounding box."""
# pylint: disable=broad-except
try:
tree = ET.parse(xml_file)
except Exception:
print('Failed to parse: ' + xml_file, file=sys.stderr)
return None
# pylint: enable=broad-except
root = tree.getroot()
num_boxes = FindNumberBoundingBoxes(root)
boxes = []
for index in xrange(num_boxes):
box = BoundingBox()
# Grab the 'index' annotation.
box.xmin = GetInt('xmin', root, index)
box.ymin = GetInt('ymin', root, index)
box.xmax = GetInt('xmax', root, index)
box.ymax = GetInt('ymax', root, index)
box.width = GetInt('width', root)
box.height = GetInt('height', root)
box.filename = GetItem('filename', root) + '.JPEG'
box.label = GetItem('name', root)
xmin = float(box.xmin) / float(box.width)
xmax = float(box.xmax) / float(box.width)
ymin = float(box.ymin) / float(box.height)
ymax = float(box.ymax) / float(box.height)
# Some images contain bounding box annotations that
# extend outside of the supplied image. See, e.g.
# n03127925/n03127925_147.xml
# Additionally, for some bounding boxes, the min > max
# or the box is entirely outside of the image.
min_x = min(xmin, xmax)
max_x = max(xmin, xmax)
box.xmin_scaled = min(max(min_x, 0.0), 1.0)
box.xmax_scaled = min(max(max_x, 0.0), 1.0)
min_y = min(ymin, ymax)
max_y = max(ymin, ymax)
box.ymin_scaled = min(max(min_y, 0.0), 1.0)
box.ymax_scaled = min(max(max_y, 0.0), 1.0)
boxes.append(box)
return boxes
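# Abridged sketch of an annotation XML as read by ProcessXMLAnnotation above.
# The element nesting is illustrative (the code only iterates over element
# names); the synset and coordinate values are hypothetical:
#
#   <annotation>
#     <filename>n02092002_123</filename>
#     <size><width>500</width><height>375</height></size>
#     <object>
#       <name>n02092002</name>
#       <bndbox><xmin>1</xmin><ymin>2</ymin><xmax>400</xmax><ymax>300</ymax></bndbox>
#     </object>
#   </annotation>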
if __name__ == '__main__':
if len(sys.argv) < 2 or len(sys.argv) > 3:
print('Invalid usage\n'
'usage: process_bounding_boxes.py <dir> [synsets-file]',
file=sys.stderr)
sys.exit(-1)
xml_files = glob.glob(sys.argv[1] + '/*/*.xml')
print('Identified %d XML files in %s' % (len(xml_files), sys.argv[1]),
file=sys.stderr)
if len(sys.argv) == 3:
labels = set([l.strip() for l in open(sys.argv[2]).readlines()])
print('Identified %d synset IDs in %s' % (len(labels), sys.argv[2]),
file=sys.stderr)
else:
labels = None
skipped_boxes = 0
skipped_files = 0
saved_boxes = 0
saved_files = 0
for file_index, one_file in enumerate(xml_files):
# Example: <...>/n06470073/n00141669_6790.xml
label = os.path.basename(os.path.dirname(one_file))
# Determine if the annotation is from an ImageNet Challenge label.
if labels is not None and label not in labels:
skipped_files += 1
continue
bboxes = ProcessXMLAnnotation(one_file)
assert bboxes is not None, 'No bounding boxes found in ' + one_file
found_box = False
for bbox in bboxes:
if labels is not None:
if bbox.label != label:
# Note: There is a slight bug in the bounding box annotation data.
# Many of the dog labels have the human label 'Scottish_deerhound'
# instead of the synset ID 'n02092002' in the bbox.label field. As a
# simple hack to overcome this issue, we only exclude bbox labels
# *which are synset IDs* that do not match the original synset label for
# the XML file.
if bbox.label in labels:
skipped_boxes += 1
continue
# Guard against improperly specified boxes.
if (bbox.xmin_scaled >= bbox.xmax_scaled or
bbox.ymin_scaled >= bbox.ymax_scaled):
skipped_boxes += 1
continue
# Note bbox.filename occasionally contains '%s' in the name. This is
# data set noise that is fixed by just using the basename of the XML file.
image_filename = os.path.splitext(os.path.basename(one_file))[0]
print('%s.JPEG,%.4f,%.4f,%.4f,%.4f' %
(image_filename,
bbox.xmin_scaled, bbox.ymin_scaled,
bbox.xmax_scaled, bbox.ymax_scaled))
saved_boxes += 1
found_box = True
if found_box:
saved_files += 1
else:
skipped_files += 1
if not file_index % 5000:
print('--> processed %d of %d XML files.' %
(file_index + 1, len(xml_files)),
file=sys.stderr)
print('--> skipped %d boxes and %d XML files.' %
(skipped_boxes, skipped_files), file=sys.stderr)
print('Finished processing %d XML files.' % len(xml_files), file=sys.stderr)
print('Skipped %d XML files not in ImageNet Challenge.' % skipped_files,
file=sys.stderr)
print('Skipped %d bounding boxes not in ImageNet Challenge.' % skipped_boxes,
file=sys.stderr)
print('Wrote %d bounding boxes from %d annotated images.' %
(saved_boxes, saved_files),
file=sys.stderr)
print('Finished.', file=sys.stderr)
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Small library that points to a data set.
Methods of the Dataset class:
data_files: Returns a python list of all (sharded) data set files.
num_examples_per_epoch: Returns the number of examples in the data set.
num_classes: Returns the number of classes in the data set.
reader: Return a reader for a single entry from the data set.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from abc import ABCMeta
from abc import abstractmethod
import os
import tensorflow as tf
FLAGS = tf.app.flags.FLAGS
# Basic model parameters.
tf.app.flags.DEFINE_string('data_dir', '/tmp/mydata',
"""Path to the processed data, i.e. """
"""TFRecord of Example protos.""")
class Dataset(object):
"""A simple class for handling data sets."""
__metaclass__ = ABCMeta
def __init__(self, name, subset):
"""Initialize dataset using a subset and the path to the data."""
assert subset in self.available_subsets(), self.available_subsets()
self.name = name
self.subset = subset
@abstractmethod
def num_classes(self):
"""Returns the number of classes in the data set."""
pass
# return 10
@abstractmethod
def num_examples_per_epoch(self):
"""Returns the number of examples in the data subset."""
pass
# if self.subset == 'train':
# return 10000
# if self.subset == 'validation':
# return 1000
@abstractmethod
def download_message(self):
"""Prints a download message for the Dataset."""
pass
def available_subsets(self):
"""Returns the list of available subsets."""
return ['train', 'validation']
def data_files(self):
"""Returns a python list of all (sharded) data subset files.
Returns:
python list of all (sharded) data set files.
Raises:
ValueError: if there are no data_files matching the subset.
"""
tf_record_pattern = os.path.join(FLAGS.data_dir, '%s-*' % self.subset)
data_files = tf.gfile.Glob(tf_record_pattern)
if not data_files:
print('No files found for dataset %s/%s at %s' % (self.name,
self.subset,
FLAGS.data_dir))
self.download_message()
exit(-1)
return data_files
def reader(self):
"""Return a reader for a single entry from the data set.
See io_ops.py for details of Reader class.
Returns:
Reader object that reads the data set.
"""
return tf.TFRecordReader()
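# Illustrative behaviour of data_files() above: with the default
# --data_dir=/tmp/mydata and subset='train', the glob pattern is
# '/tmp/mydata/train-*', which matches sharded TFRecord files such as
# 'train-00000-of-01024' produced by the build scripts above.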
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Small library that points to the flowers data set.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from inception.dataset import Dataset
class FlowersData(Dataset):
"""Flowers data set."""
def __init__(self, subset):
super(FlowersData, self).__init__('Flowers', subset)
def num_classes(self):
"""Returns the number of classes in the data set."""
return 5
def num_examples_per_epoch(self):
"""Returns the number of examples in the data subset."""
if self.subset == 'train':
return 3170
if self.subset == 'validation':
return 500
def download_message(self):
"""Instruction to download and extract the tarball from Flowers website."""
print('Failed to find any Flowers %s files' % self.subset)
print('')
print('If you have already downloaded and processed the data, then make '
'sure to set --data_dir to point to the directory containing the '
'location of the sharded TFRecords.\n')
print('Please see README.md for instructions on how to build '
'the flowers dataset using download_and_preprocess_flowers.\n')
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A binary to evaluate Inception on the flowers data set.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from inception import inception_eval
from inception.flowers_data import FlowersData
FLAGS = tf.app.flags.FLAGS
def main(unused_argv=None):
dataset = FlowersData(subset=FLAGS.subset)
assert dataset.data_files()
if tf.gfile.Exists(FLAGS.eval_dir):
tf.gfile.DeleteRecursively(FLAGS.eval_dir)
tf.gfile.MakeDirs(FLAGS.eval_dir)
inception_eval.evaluate(dataset)
if __name__ == '__main__':
tf.app.run()
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A binary to train Inception on the flowers data set.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from inception import inception_train
from inception.flowers_data import FlowersData
FLAGS = tf.app.flags.FLAGS
def main(_):
dataset = FlowersData(subset=FLAGS.subset)
assert dataset.data_files()
if tf.gfile.Exists(FLAGS.train_dir):
tf.gfile.DeleteRecursively(FLAGS.train_dir)
tf.gfile.MakeDirs(FLAGS.train_dir)
inception_train.train(dataset)
if __name__ == '__main__':
tf.app.run()
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Read and preprocess image data.
Image processing occurs on a single image at a time. Images are read and
preprocessed in parallel across multiple threads. The resulting images
are concatenated together to form a single batch for training or evaluation.
-- Provide processed image data for a network:
inputs: Construct batches of evaluation examples of images.
distorted_inputs: Construct batches of training examples of images.
batch_inputs: Construct batches of training or evaluation examples of images.
-- Data processing:
parse_example_proto: Parses an Example proto containing a training example
of an image.
-- Image decoding:
decode_jpeg: Decode a JPEG encoded string into a 3-D float32 Tensor.
-- Image preprocessing:
image_preprocessing: Decode and preprocess one image for evaluation or training
distort_image: Distort one image for training a network.
eval_image: Prepare one image for evaluation.
distort_color: Distort the color in one image for training.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer('batch_size', 32,
"""Number of images to process in a batch.""")
tf.app.flags.DEFINE_integer('image_size', 299,
"""Provide square images of this size.""")
tf.app.flags.DEFINE_integer('num_preprocess_threads', 4,
"""Number of preprocessing threads per tower. """
"""Please make this a multiple of 4.""")
# Images are preprocessed asynchronously using multiple threads specified by
# --num_preprocess_threads and the resulting processed images are stored in a
# random shuffling queue. The shuffling queue dequeues --batch_size images
# for processing on a given Inception tower. A larger shuffling queue guarantees
# better mixing across examples within a batch and results in slightly higher
# predictive performance in a trained model. Empirically,
# --input_queue_memory_factor=16 works well. A value of 16 implies a queue size
# of 1024*16 images. Assuming RGB 299x299 images, this implies a queue size of
# 16GB. If the machine is memory limited, then decrease this factor to
# reduce the CPU memory footprint accordingly.
tf.app.flags.DEFINE_integer('input_queue_memory_factor', 16,
"""Size of the queue of preprocessed images. """
"""Default is ideal but try smaller values, e.g. """
"""4, 2 or 1, if host memory is constrained. See """
"""comments in code for more details.""")
def inputs(dataset, batch_size=None, num_preprocess_threads=None):
"""Generate batches of ImageNet images for evaluation.
Use this function as the inputs for evaluating a network.
Note that some (minimal) image preprocessing occurs during evaluation
including central cropping and resizing of the image to fit the network.
Args:
dataset: instance of Dataset class specifying the dataset.
batch_size: integer, number of examples in batch
num_preprocess_threads: integer, total number of preprocessing threads;
None defaults to FLAGS.num_preprocess_threads.
Returns:
images: Images. 4D tensor of size [batch_size, FLAGS.image_size,
FLAGS.image_size, 3].
labels: 1-D integer Tensor of [batch_size].
"""
if not batch_size:
batch_size = FLAGS.batch_size
# Force all input processing onto CPU in order to reserve the GPU for
# the forward inference and back-propagation.
with tf.device('/cpu:0'):
images, labels = batch_inputs(
dataset, batch_size, train=False,
num_preprocess_threads=num_preprocess_threads)
return images, labels
def distorted_inputs(dataset, batch_size=None, num_preprocess_threads=None):
"""Generate batches of distorted versions of ImageNet images.
Use this function as the inputs for training a network.
Distorting images provides a useful technique for augmenting the data
set during training in order to make the network invariant to aspects
of the image that do not affect the label.
Args:
dataset: instance of Dataset class specifying the dataset.
batch_size: integer, number of examples in batch
num_preprocess_threads: integer, total number of preprocessing threads;
None defaults to FLAGS.num_preprocess_threads.
Returns:
images: Images. 4D tensor of size [batch_size, FLAGS.image_size,
FLAGS.image_size, 3].
labels: 1-D integer Tensor of [batch_size].
"""
if not batch_size:
batch_size = FLAGS.batch_size
# Force all input processing onto CPU in order to reserve the GPU for
# the forward inference and back-propagation.
with tf.device('/cpu:0'):
images, labels = batch_inputs(
dataset, batch_size, train=True,
num_preprocess_threads=num_preprocess_threads)
return images, labels
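# Minimal usage sketch (illustrative), pairing the input functions above with
# a Dataset subclass defined elsewhere in this directory:
#
#   from inception.flowers_data import FlowersData
#   dataset = FlowersData(subset='train')
#   images, labels = distorted_inputs(dataset, batch_size=32)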
def decode_jpeg(image_buffer, scope=None):
"""Decode a JPEG string into one 3-D float image Tensor.
Args:
image_buffer: scalar string Tensor.
scope: Optional scope for op_scope.
Returns:
    3-D float Tensor with values in [0, 1).
"""
with tf.op_scope([image_buffer], scope, 'decode_jpeg'):
# Decode the string as an RGB JPEG.
# Note that the resulting image contains an unknown height and width
# that is set dynamically by decode_jpeg. In other words, the height
# and width of image is unknown at compile-time.
image = tf.image.decode_jpeg(image_buffer, channels=3)
    # After this point, all image pixels reside in [0, 1)
    # until the very end, when they are rescaled to [-1, 1). The various
    # adjust_* ops all require this range for dtype float.
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
return image
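# Illustrative standalone use (the file path is hypothetical):
#
#   raw = tf.read_file('/tmp/example.jpg')
#   image = decode_jpeg(raw)  # 3-D float32 Tensor with values in [0, 1)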
def distort_color(image, thread_id=0, scope=None):
"""Distort the color of the image.
Each color distortion is non-commutative and thus ordering of the color ops
matters. Ideally we would randomly permute the ordering of the color ops.
  Rather than adding that level of complication, we select a distinct ordering
of color ops for each preprocessing thread.
Args:
image: Tensor containing single image.
thread_id: preprocessing thread ID.
scope: Optional scope for op_scope.
Returns:
color-distorted image
"""
with tf.op_scope([image], scope, 'distort_color'):
color_ordering = thread_id % 2
if color_ordering == 0:
image = tf.image.random_brightness(image, max_delta=32. / 255.)
image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
image = tf.image.random_hue(image, max_delta=0.2)
image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
elif color_ordering == 1:
image = tf.image.random_brightness(image, max_delta=32. / 255.)
image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
image = tf.image.random_hue(image, max_delta=0.2)
# The random_* ops do not necessarily clamp.
image = tf.clip_by_value(image, 0.0, 1.0)
return image
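# Illustrative: even thread ids select ordering 0 and odd ids select
# ordering 1, so a multi-threaded pipeline mixes both orderings:
#
#   a = distort_color(image, thread_id=0)  # brightness/saturation/hue/contrast
#   b = distort_color(image, thread_id=1)  # brightness/contrast/saturation/hue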
def distort_image(image, height, width, bbox, thread_id=0, scope=None):
"""Distort one image for training a network.
Distorting images provides a useful technique for augmenting the data
set during training in order to make the network invariant to aspects
  of the image that do not affect the label.
Args:
image: 3-D float Tensor of image
height: integer
width: integer
bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
where each coordinate is [0, 1) and the coordinates are arranged
as [ymin, xmin, ymax, xmax].
thread_id: integer indicating the preprocessing thread.
scope: Optional scope for op_scope.
Returns:
3-D float Tensor of distorted image used for training.
"""
with tf.op_scope([image, height, width, bbox], scope, 'distort_image'):
# Each bounding box has shape [1, num_boxes, box coords] and
# the coordinates are ordered [ymin, xmin, ymax, xmax].
# Display the bounding box in the first thread only.
if not thread_id:
image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0),
bbox)
tf.image_summary('image_with_bounding_boxes', image_with_box)
# A large fraction of image datasets contain a human-annotated bounding
# box delineating the region of the image containing the object of interest.
# We choose to create a new bounding box for the object which is a randomly
# distorted version of the human-annotated bounding box that obeys an allowed
# range of aspect ratios, sizes and overlap with the human-annotated
# bounding box. If no box is supplied, then we assume the bounding box is
# the entire image.
sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
tf.shape(image),
bounding_boxes=bbox,
min_object_covered=0.1,
aspect_ratio_range=[0.75, 1.33],
area_range=[0.05, 1.0],
max_attempts=100,
use_image_if_no_bounding_boxes=True)
bbox_begin, bbox_size, distort_bbox = sample_distorted_bounding_box
if not thread_id:
image_with_distorted_box = tf.image.draw_bounding_boxes(
tf.expand_dims(image, 0), distort_bbox)
tf.image_summary('images_with_distorted_bounding_box',
image_with_distorted_box)
# Crop the image to the specified bounding box.
distorted_image = tf.slice(image, bbox_begin, bbox_size)
# This resizing operation may distort the images because the aspect
# ratio is not respected. We select a resize method in a round robin
# fashion based on the thread number.
# Note that ResizeMethod contains 4 enumerated resizing methods.
resize_method = thread_id % 4
distorted_image = tf.image.resize_images(distorted_image, height, width,
resize_method)
# Restore the shape since the dynamic slice based upon the bbox_size loses
# the third dimension.
distorted_image.set_shape([height, width, 3])
if not thread_id:
tf.image_summary('cropped_resized_image',
tf.expand_dims(distorted_image, 0))
# Randomly flip the image horizontally.
distorted_image = tf.image.random_flip_left_right(distorted_image)
# Randomly distort the colors.
distorted_image = distort_color(distorted_image, thread_id)
if not thread_id:
tf.image_summary('final_distorted_image',
tf.expand_dims(distorted_image, 0))
return distorted_image
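# A minimal calling sketch: when no annotated box is available, callers can
# pass a whole-image bounding box of shape [1, 1, 4], ordered
# [ymin, xmin, ymax, xmax] (illustrative only):
#
#   whole_image_bbox = tf.constant([0.0, 0.0, 1.0, 1.0],
#                                  dtype=tf.float32, shape=[1, 1, 4])
#   distorted = distort_image(image, 299, 299, whole_image_bbox)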
def eval_image(image, height, width, scope=None):
"""Prepare one image for evaluation.
Args:
image: 3-D float Tensor
height: integer
width: integer
scope: Optional scope for op_scope.
Returns:
3-D float Tensor of prepared image.
"""
with tf.op_scope([image, height, width], scope, 'eval_image'):
    # Crop the central region of the image, keeping 87.5% of the width and
    # height (roughly 76.6% of the original area).
image = tf.image.central_crop(image, central_fraction=0.875)
# Resize the image to the original height and width.
image = tf.expand_dims(image, 0)
image = tf.image.resize_bilinear(image, [height, width],
align_corners=False)
image = tf.squeeze(image, [0])
return image
def image_preprocessing(image_buffer, bbox, train, thread_id=0):
"""Decode and preprocess one image for evaluation or training.
Args:
image_buffer: JPEG encoded string Tensor
bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
where each coordinate is [0, 1) and the coordinates are arranged as
[ymin, xmin, ymax, xmax].
train: boolean
thread_id: integer indicating preprocessing thread
Returns:
3-D float Tensor containing an appropriately scaled image
Raises:
ValueError: if user does not provide bounding box
"""
if bbox is None:
raise ValueError('Please supply a bounding box.')
image = decode_jpeg(image_buffer)
height = FLAGS.image_size
width = FLAGS.image_size
if train:
image = distort_image(image, height, width, bbox, thread_id)
else:
image = eval_image(image, height, width)
  # Finally, rescale to [-1, 1) instead of [0, 1)
image = tf.sub(image, 0.5)
image = tf.mul(image, 2.0)
return image
def parse_example_proto(example_serialized):
"""Parses an Example proto containing a training example of an image.
The output of the build_image_data.py image preprocessing script is a dataset
containing serialized Example protocol buffers. Each Example proto contains
the following fields:
image/height: 462
image/width: 581
image/colorspace: 'RGB'
image/channels: 3
image/class/label: 615
image/class/synset: 'n03623198'
image/class/text: 'knee pad'
image/object/bbox/xmin: 0.1
image/object/bbox/xmax: 0.9
image/object/bbox/ymin: 0.2
image/object/bbox/ymax: 0.6
image/object/bbox/label: 615
image/format: 'JPEG'
image/filename: 'ILSVRC2012_val_00041207.JPEG'
image/encoded: <JPEG encoded string>
Args:
example_serialized: scalar Tensor tf.string containing a serialized
Example protocol buffer.
Returns:
image_buffer: Tensor tf.string containing the contents of a JPEG file.
label: Tensor tf.int32 containing the label.
bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
where each coordinate is [0, 1) and the coordinates are arranged as
[ymin, xmin, ymax, xmax].
text: Tensor tf.string containing the human-readable label.
"""
# Dense features in Example proto.
feature_map = {
'image/encoded': tf.FixedLenFeature([], dtype=tf.string,
default_value=''),
'image/class/label': tf.FixedLenFeature([1], dtype=tf.int64,
default_value=-1),
'image/class/text': tf.FixedLenFeature([], dtype=tf.string,
default_value=''),
}
sparse_float32 = tf.VarLenFeature(dtype=tf.float32)
# Sparse features in Example proto.
feature_map.update(
{k: sparse_float32 for k in ['image/object/bbox/xmin',
'image/object/bbox/ymin',
'image/object/bbox/xmax',
'image/object/bbox/ymax']})
features = tf.parse_single_example(example_serialized, feature_map)
label = tf.cast(features['image/class/label'], dtype=tf.int32)
xmin = tf.expand_dims(features['image/object/bbox/xmin'].values, 0)
ymin = tf.expand_dims(features['image/object/bbox/ymin'].values, 0)
xmax = tf.expand_dims(features['image/object/bbox/xmax'].values, 0)
ymax = tf.expand_dims(features['image/object/bbox/ymax'].values, 0)
  # Note that we impose an ordering of (y, x) to match the
  # [ymin, xmin, ymax, xmax] convention expected by the TensorFlow image ops.
bbox = tf.concat(0, [ymin, xmin, ymax, xmax])
# Force the variable number of bounding boxes into the shape
# [1, num_boxes, coords].
bbox = tf.expand_dims(bbox, 0)
bbox = tf.transpose(bbox, [0, 2, 1])
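  # Shape walk-through (illustrative): with e.g. two boxes, each coordinate
  # Tensor above has shape [1, 2]; the concat along dim 0 yields [4, 2],
  # expand_dims yields [1, 4, 2], and the transpose to [0, 2, 1] produces
  # the final [1, 2, 4] = [1, num_boxes, coords] layout.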
return features['image/encoded'], label, bbox, features['image/class/text']
def batch_inputs(dataset, batch_size, train, num_preprocess_threads=None):
"""Contruct batches of training or evaluation examples from the image dataset.
Args:
dataset: instance of Dataset class specifying the dataset.
See dataset.py for details.
batch_size: integer
train: boolean
num_preprocess_threads: integer, total number of preprocessing threads
Returns:
images: 4-D float Tensor of a batch of images
labels: 1-D integer Tensor of [batch_size].
Raises:
ValueError: if data is not found
"""
with tf.name_scope('batch_processing'):
data_files = dataset.data_files()
    if not data_files:
      raise ValueError('No data files found for this dataset')
filename_queue = tf.train.string_input_producer(data_files, capacity=16)
if num_preprocess_threads is None:
num_preprocess_threads = FLAGS.num_preprocess_threads
if num_preprocess_threads % 4:
raise ValueError('Please make num_preprocess_threads a multiple '
                       'of 4 (%d %% 4 != 0).' % num_preprocess_threads)
# Create a subgraph with its own reader (but sharing the
# filename_queue) for each preprocessing thread.
images_and_labels = []
for thread_id in range(num_preprocess_threads):
reader = dataset.reader()
_, example_serialized = reader.read(filename_queue)
# Parse a serialized Example proto to extract the image and metadata.
image_buffer, label_index, bbox, _ = parse_example_proto(
example_serialized)
image = image_preprocessing(image_buffer, bbox, train, thread_id)
images_and_labels.append([image, label_index])
# Approximate number of examples per shard.
examples_per_shard = 1024
# Size the random shuffle queue to balance between good global
# mixing (more examples) and memory use (fewer examples).
    # 1 image uses 299*299*3*4 bytes ~= 1.07MB
    # The default input_queue_memory_factor is 16, implying a shuffling queue
    # size of: examples_per_shard * 16 * ~1.07MB ~= 17.6GB
min_queue_examples = examples_per_shard * FLAGS.input_queue_memory_factor
# Create a queue that produces the examples in batches after shuffling.
if train:
images, label_index_batch = tf.train.shuffle_batch_join(
images_and_labels,
batch_size=batch_size,
capacity=min_queue_examples + 3 * batch_size,
min_after_dequeue=min_queue_examples)
else:
images, label_index_batch = tf.train.batch_join(
images_and_labels,
batch_size=batch_size,
capacity=min_queue_examples + 3 * batch_size)
# Reshape images into these desired dimensions.
height = FLAGS.image_size
width = FLAGS.image_size
depth = 3
images = tf.cast(images, tf.float32)
images = tf.reshape(images, shape=[batch_size, height, width, depth])
# Display the training images in the visualizer.
tf.image_summary('images', images)
return images, tf.reshape(label_index_batch, [batch_size])
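# A minimal sketch of driving this pipeline end to end (assumes a Dataset
# instance as defined in dataset.py; illustrative only):
#
#   images, labels = distorted_inputs(dataset)
#   with tf.Session() as sess:
#     coord = tf.train.Coordinator()
#     threads = tf.train.start_queue_runners(sess=sess, coord=coord)
#     image_batch, label_batch = sess.run([images, labels])
#     coord.request_stop()
#     coord.join(threads)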