# Inception in TensorFlow

[ImageNet](http://www.image-net.org/) is a common academic data set in machine
learning for training an image recognition system. Code in this directory
demonstrates how to use TensorFlow to train and evaluate a type of convolutional
neural network (CNN) on this academic data set. In particular, we demonstrate
how to train the Inception v3 architecture as specified in:

_Rethinking the Inception Architecture for Computer Vision_

Christian Szegedy, Vincent Vanhoucke, Sergey Ioffe, Jonathon Shlens, Zbigniew
Wojna

http://arxiv.org/abs/1512.00567

This network achieves 21.2% top-1 and 5.6% top-5 error for single-frame
evaluation with a computational cost of 5 billion multiply-adds per inference,
while using fewer than 25 million parameters. Below is a visualization of the
model architecture.

![Inception-v3 Architecture](g3doc/inception_v3_architecture.png)

## Description of Code

**NOTE**: For the most part, you will find a newer version of this code at [models/slim](https://github.com/tensorflow/models/tree/master/slim). In particular:

*   `inception_train.py` and `imagenet_train.py` should no longer be used. The slim editions for running on multiple GPUs are the current best examples.
*   `inception_distributed_train.py` and `imagenet_distributed_train.py` are still valid examples of distributed training.

The code base provides three core binaries for:

*   Training an Inception v3 network from scratch across multiple GPUs and/or
    multiple machines using the ImageNet 2012 Challenge training data set.
*   Evaluating an Inception v3 network using the ImageNet 2012 Challenge
    validation data set.
*   Retraining an Inception v3 network on a novel task and back-propagating the
    errors to fine-tune the network weights.

The training procedure employs synchronous stochastic gradient descent across
multiple GPUs. The user may specify the number of GPUs they wish to harness. The
synchronous training performs *batch-splitting* by dividing a given batch across
multiple GPUs.
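
As a rough illustration of batch-splitting, the sketch below divides one batch into equal per-GPU slices. It is plain Python rather than TensorFlow. The helper name `split_batch` is hypothetical; the real training script builds one model tower per GPU inside the graph. The sketch also assumes that the batch size is divisible by the number of GPUs.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def split_batch(batch: Sequence[T], num_gpus: int) -> List[List[T]]:
    """Split one batch into num_gpus equally sized slices (hypothetical helper).

    Assumes len(batch) is divisible by num_gpus, so that every GPU tower
    receives the same number of examples.
    """
    if num_gpus < 1:
        raise ValueError("num_gpus must be at least 1")
    if len(batch) % num_gpus != 0:
        raise ValueError("batch size must be divisible by num_gpus")
    per_gpu = len(batch) // num_gpus
    return [list(batch[i * per_gpu:(i + 1) * per_gpu]) for i in range(num_gpus)]


if __name__ == "__main__":
    # A batch of 64 example indices split across 2 GPUs gives 2 slices of 32.
    slices = split_batch(list(range(64)), num_gpus=2)
    print([len(s) for s in slices])  # [32, 32]
```

Each slice is what a single tower would see in one step. The full batch size is still the value passed as `--batch_size`.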

The training setup is nearly identical to the section [Training a Model Using
Multiple GPU Cards](https://www.tensorflow.org/tutorials/deep_cnn/index.html#training-a-model-using-multiple-gpu-cards),
where we have substituted the CIFAR-10 model architecture with Inception v3. The
primary differences with that setup are:

*   Calculate and update the batch-norm statistics during training so that they
    may be substituted in during evaluation.
*   Specify the model architecture using a (still experimental) higher level
    language called TensorFlow-Slim.

For more details about TensorFlow-Slim, please see the
[Slim README](inception/slim/README.md). Please note that this higher-level
language is still *experimental* and the API may change over time depending on
usage and subsequent research.

## Getting Started

Before you run the training script for the first time, you will need to download
and convert the ImageNet data to native TFRecord format. The TFRecord format
consists of a set of sharded files where each entry is a serialized `tf.Example`
proto. Each `tf.Example` proto contains the ImageNet image (JPEG encoded) as
well as metadata such as label and bounding box information. See
[`parse_example_proto`](inception/image_processing.py) for details.

We provide a single [script](inception/data/download_and_preprocess_imagenet.sh)
for downloading and converting ImageNet data to TFRecord format. Downloading and
preprocessing the data may take several hours (up to half a day) depending on
your network and computer speed. Please be patient.

To begin, you will need to sign up for an account with
[ImageNet](http://image-net.org) to gain access to the data. Look for the sign-up
page, create an account and request an access key to download the data.

After you have `USERNAME` and `PASSWORD`, you are ready to run our script. Make
sure that your hard disk has at least 500 GB of free space for downloading and
storing the data. Here we select `DATA_DIR=$HOME/imagenet-data` as such a
location, but feel free to edit accordingly.

When you run the below script, please enter *USERNAME* and *PASSWORD* when
prompted. This will occur at the very beginning. Once these values are entered,
you will not need to interact with the script again.

```shell
# location of where to place the ImageNet data
DATA_DIR=$HOME/imagenet-data

# build the preprocessing script.
bazel build inception/download_and_preprocess_imagenet

# run it
bazel-bin/inception/download_and_preprocess_imagenet "${DATA_DIR}"
```

The final line of the script's output should read:

```shell
2016-02-17 14:30:17.287989: Finished writing all 1281167 images in data set.
```

When the script finishes, you will find 1024 training files and 128 validation
files in the `DATA_DIR`. The files will match the patterns
`train-?????-of-01024` and `validation-?????-of-00128`, respectively.
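
The shard names follow a zero-padded `<prefix>-<index>-of-<total>` convention. The sketch below generates the expected names and reports which ones are missing from a directory, which is a quick sanity check after preprocessing. The helper names are hypothetical. The five-digit padding is inferred from the patterns above, and `~/imagenet-data` is just the example `DATA_DIR`.

```python
import os
from typing import List


def expected_shards(prefix: str, num_shards: int) -> List[str]:
    """Return names such as 'train-00000-of-01024' (hypothetical helper)."""
    return ["%s-%05d-of-%05d" % (prefix, i, num_shards) for i in range(num_shards)]


def missing_shards(data_dir: str, prefix: str, num_shards: int) -> List[str]:
    """List the expected shard files that are not present in data_dir."""
    present = set(os.listdir(data_dir))
    return [name for name in expected_shards(prefix, num_shards) if name not in present]


if __name__ == "__main__":
    data_dir = os.path.expanduser("~/imagenet-data")  # assumed DATA_DIR from above
    if not os.path.isdir(data_dir):
        print(data_dir, "does not exist yet")
    else:
        for prefix, count in (("train", 1024), ("validation", 128)):
            missing = missing_shards(data_dir, prefix, count)
            print(prefix, "shards missing:", len(missing))
```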

[Congratulations!](https://www.youtube.com/watch?v=9bZkp7q19f0) You are now
ready to train or evaluate with the ImageNet data set.

## How to Train from Scratch

**WARNING** Training an Inception v3 network from scratch is a computationally
intensive task and, depending on your compute setup, may take several days or
even weeks.
*Before proceeding*, please read the
[Convolutional Neural Networks](https://www.tensorflow.org/tutorials/deep_cnn/index.html)
tutorial; in particular, focus on
[Training a Model Using Multiple GPU Cards](https://www.tensorflow.org/tutorials/deep_cnn/index.html#training-a-model-using-multiple-gpu-cards).
The model training method is nearly identical to that described in the
CIFAR-10 multi-GPU model training. Briefly, the model training:

*   Places an individual model replica on each GPU and splits the batch across
    the GPUs.
*   Updates model parameters synchronously by waiting for all GPUs to finish
    processing a batch of data.

The training procedure is summarized by this diagram of how operations and
variables are placed on the CPU and GPUs, respectively.

<div style="width:40%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%" src="https://www.tensorflow.org/images/Parallelism.png">
</div>

Each tower computes the gradients for a portion of the batch and the gradients
are combined and averaged across the multiple towers in order to provide a
single update of the Variables stored on the CPU.
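
The averaging step can be sketched in plain Python. Each tower contributes one gradient per variable, and the single update applied to the shared variables uses the element-wise mean. This is a hypothetical stand-in for the graph ops the training script builds. Gradients are represented here as lists of floats keyed by variable name, and a plain SGD step is assumed for the update.

```python
from typing import Dict, List

Gradients = Dict[str, List[float]]


def average_tower_gradients(tower_grads: List[Gradients]) -> Gradients:
    """Element-wise mean of each variable's gradient across all towers."""
    num_towers = len(tower_grads)
    averaged: Gradients = {}
    for name in tower_grads[0]:
        per_tower = [grads[name] for grads in tower_grads]
        averaged[name] = [sum(vals) / num_towers for vals in zip(*per_tower)]
    return averaged


def sgd_update(params: Gradients, grads: Gradients, learning_rate: float) -> Gradients:
    """Apply one synchronous SGD step to the shared (CPU-resident) parameters."""
    return {
        name: [p - learning_rate * g for p, g in zip(values, grads[name])]
        for name, values in params.items()
    }


if __name__ == "__main__":
    # Two towers, each holding gradients for the same two variables.
    tower_0 = {"weights": [0.5, -1.0], "biases": [1.0]}
    tower_1 = {"weights": [1.5, 0.0], "biases": [3.0]}
    avg = average_tower_gradients([tower_0, tower_1])
    print(avg)  # {'weights': [1.0, -0.5], 'biases': [2.0]}
    params = {"weights": [1.0, 1.0], "biases": [0.0]}
    print(sgd_update(params, avg, learning_rate=0.1))
```

Because every tower waits for the others before the update, the result matches a single-device step over the full batch. This holds only if the loss is averaged over examples and the slices have equal size.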

A crucial aspect of training a network of this size is *training speed* in terms
of wall-clock time. The training speed is dictated by many factors, most
importantly the batch size and the learning rate schedule. Both of these
parameters are heavily coupled to the hardware setup.

Batch size is a difficult parameter to tune because it requires balancing the
memory demands of the model, the memory available on the GPU and the speed of
computation. Generally speaking, employing larger batch sizes leads to more
efficient computation and potentially more efficient training steps.

We have tested several hardware setups for training this model from scratch, but
we emphasize that, depending on your hardware setup, you may need to adapt the
batch size and learning rate schedule.

Please see the comments in `inception_train.py` for a few selected learning rate
plans based on some selected hardware setups.

To train this model, you simply need to specify the following:

```shell
# Build the model. Note that we need to make sure TensorFlow is ready to
# use before this, as this command will not build TensorFlow.
bazel build inception/imagenet_train

# run it
bazel-bin/inception/imagenet_train --num_gpus=1 --batch_size=32 --train_dir=/tmp/imagenet_train --data_dir=/tmp/imagenet_data
```

The model reads in the ImageNet training data from `--data_dir`. If you followed
the instructions in [Getting Started](#getting-started), then set
`--data_dir="${DATA_DIR}"`. The script assumes that there exists a set of
sharded TFRecord files containing the ImageNet data. If you have not created
TFRecord files, please refer to [Getting Started](#getting-started).

Here is the output of the above command line when running on a Tesla K40c:

```shell
2016-03-07 12:24:59.922898: step 0, loss = 13.11 (5.3 examples/sec; 6.064 sec/batch)
2016-03-07 12:25:55.206783: step 10, loss = 13.71 (9.4 examples/sec; 3.394 sec/batch)
2016-03-07 12:26:28.905231: step 20, loss = 14.81 (9.5 examples/sec; 3.380 sec/batch)
2016-03-07 12:27:02.699719: step 30, loss = 14.45 (9.5 examples/sec; 3.378 sec/batch)
2016-03-07 12:27:36.515699: step 40, loss = 13.98 (9.5 examples/sec; 3.376 sec/batch)
2016-03-07 12:28:10.220956: step 50, loss = 13.92 (9.6 examples/sec; 3.327 sec/batch)
2016-03-07 12:28:43.658223: step 60, loss = 13.28 (9.6 examples/sec; 3.350 sec/batch)
...
```

In this example, a log entry is printed every 10 steps and the line includes the
total loss (starts around 13.0-14.0) and the speed of processing in terms of
throughput (examples/sec) and batch speed (sec/batch).
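
The two speed figures are consistent with examples/sec being about `batch_size` divided by sec/batch. For instance, with `--batch_size=32`, 32 / 3.394 is about 9.4 examples/sec, which matches step 10 above. The sketch below parses such log lines and recovers the implied batch size. The regular expression is written against the sample single-machine output shown here and is an assumption about the exact log format; the distributed worker log, for example, is formatted slightly differently.

```python
import re
from typing import NamedTuple, Optional

# Matches lines like:
# 2016-03-07 12:25:55.206783: step 10, loss = 13.71 (9.4 examples/sec; 3.394 sec/batch)
LOG_RE = re.compile(
    r"step (?P<step>\d+), loss = (?P<loss>[\d.]+) "
    r"\((?P<eps>[\d.]+) examples/sec; (?P<spb>[\d.]+) sec/batch\)"
)


class StepLog(NamedTuple):
    step: int
    loss: float
    examples_per_sec: float
    sec_per_batch: float


def parse_step(line: str) -> Optional[StepLog]:
    """Parse one training log line; return None if it does not match."""
    m = LOG_RE.search(line)
    if m is None:
        return None
    return StepLog(int(m["step"]), float(m["loss"]), float(m["eps"]), float(m["spb"]))


def implied_batch_size(entry: StepLog) -> float:
    """Throughput times seconds per batch approximates the batch size."""
    return entry.examples_per_sec * entry.sec_per_batch


if __name__ == "__main__":
    line = ("2016-03-07 12:25:55.206783: step 10, loss = 13.71 "
            "(9.4 examples/sec; 3.394 sec/batch)")
    entry = parse_step(line)
    print(entry)
    print(round(implied_batch_size(entry)))  # 32
```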

The number of GPU devices is specified by `--num_gpus` (which defaults to 1).
Specifying `--num_gpus` greater than 1 splits the batch evenly across the GPU
cards.

```shell
# Build the model. Note that we need to make sure TensorFlow is ready to
# use before this, as this command will not build TensorFlow.
bazel build inception/imagenet_train

# run it
bazel-bin/inception/imagenet_train --num_gpus=2 --batch_size=64 --train_dir=/tmp/imagenet_train
```

This model splits the batch of 64 images across 2 GPUs and calculates the
average gradient by waiting for both GPUs to finish calculating the gradients
from their respective data (see diagram above). Generally speaking, using larger
numbers of GPUs leads to higher throughput as well as the opportunity to use
larger batch sizes. In turn, larger batch sizes imply better estimates of the
gradient, enabling the usage of higher learning rates. In summary, using more
GPUs results in faster training.


Note that there is considerable noise in the loss function on individual steps
in the previous log. Because of this noise, it is difficult to discern how well
a model is learning. To address this, launch TensorBoard pointing to the
directory containing the events log.

```shell
tensorboard --logdir=/tmp/imagenet_train
```

TensorBoard has access to the many summaries produced by the model that describe
multitudes of statistics tracking the model behavior and the quality of the
learned model. In particular, TensorBoard tracks an exponentially smoothed
version of the loss. In practice, it is far easier to judge how well a model
learns by monitoring the smoothed version of the loss.
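
To see why smoothing helps, here is a minimal exponential moving average over a noisy loss sequence. It is a generic sketch and not TensorBoard's exact smoothing implementation. The `weight` parameter plays a role similar in spirit to TensorBoard's smoothing slider. The sample values are the losses from the single-GPU K40c log above.

```python
from typing import Iterable, List, Optional


def exponential_smoothing(values: Iterable[float], weight: float = 0.6) -> List[float]:
    """Return an exponentially smoothed copy of values.

    Each output is weight * previous_smoothed + (1 - weight) * current_value,
    seeded with the first value. A higher weight means heavier smoothing.
    """
    if not 0.0 <= weight < 1.0:
        raise ValueError("weight must be in [0, 1)")
    smoothed: List[float] = []
    last: Optional[float] = None
    for value in values:
        last = value if last is None else weight * last + (1.0 - weight) * value
        smoothed.append(last)
    return smoothed


if __name__ == "__main__":
    # Losses at steps 0-60 from the single-GPU log above.
    losses = [13.11, 13.71, 14.81, 14.45, 13.98, 13.92, 13.28]
    for raw, smooth in zip(losses, exponential_smoothing(losses)):
        print(f"raw={raw:.2f} smoothed={smooth:.2f}")
```

The smoothed curve swings over a noticeably narrower range than the raw values, which makes the trend easier to judge.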

## How to Train from Scratch in a Distributed Setting

**NOTE** Distributed TensorFlow requires version 0.8 or later.

Distributed TensorFlow lets us use multiple machines to train a model faster.
This is quite different from training with multiple GPU towers on a single
machine, where all parameters and gradient computations are in the same place. We
coordinate the computation across multiple machines by employing a centralized
repository for parameters that maintains a unified, single copy of model
parameters. Each individual machine sends gradient updates to the centralized
parameter repository, which coordinates these updates and sends back updated
parameters to the individual machines running the model training.

We term each machine that runs a copy of the training a `worker` or `replica`.
We term each machine that maintains model parameters a `ps`, short for
`parameter server`. Note that we might have more than one machine acting as a
`ps` as the model parameters may be sharded across multiple machines.
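
When more than one `ps` is used, each variable lives on exactly one of them. One simple sharding scheme assigns variables to `ps` tasks round-robin in creation order; as far as we know, this is also the default strategy of TensorFlow's `tf.train.replica_device_setter`. The sketch below imitates that idea in plain Python and does not call TensorFlow. The variable names are illustrative only.

```python
from typing import Dict, List


def assign_variables_round_robin(variable_names: List[str], num_ps_tasks: int) -> Dict[str, str]:
    """Map each variable to a ps device string, cycling through the ps tasks."""
    if num_ps_tasks < 1:
        raise ValueError("need at least one ps task")
    return {
        name: "/job:ps/task:%d" % (index % num_ps_tasks)
        for index, name in enumerate(variable_names)
    }


if __name__ == "__main__":
    # Illustrative variable names; the real model defines many more.
    names = ["conv0/weights", "conv0/biases", "conv1/weights", "logits/logits/biases"]
    for name, device in assign_variables_round_robin(names, num_ps_tasks=2).items():
        print(f"{name:25s} -> {device}")
```

Round-robin placement ignores variable sizes, so a few large variables can still make one `ps` busier than the others.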

Variables may be updated with synchronous or asynchronous gradient updates. One
may construct an
[`Optimizer`](https://www.tensorflow.org/api_docs/python/train.html#optimizers)
in TensorFlow that constructs the necessary graph for either case, diagrammed
below from the TensorFlow
[Whitepaper](http://download.tensorflow.org/paper/whitepaper2015.pdf):

<div style="width:40%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%"
  src="https://www.tensorflow.org/images/tensorflow_figure7.png">
</div>

In [a recent paper](https://arxiv.org/abs/1604.00981), synchronous gradient
updates have been shown to reach higher accuracy in a shorter amount of time.
In this distributed Inception example, we employ synchronous gradient updates.

Note that in this example each replica has a single tower that uses one GPU.

The command-line flags `worker_hosts` and `ps_hosts` specify available servers.
The same binary will be used for both the `worker` jobs and the `ps` jobs.
The command-line flag `job_name` specifies what role a task plays, and
`task_id` identifies which task within that job it is. Several things to note
here:

*   The numbers of `ps` and `worker` tasks are inferred from the lists of hosts
    specified in the flags. The `task_id` should be within the range `[0,
    num_ps_tasks)` for `ps` tasks and `[0, num_worker_tasks)` for `worker`
    tasks.
*   `ps` and `worker` tasks can run on the same machine, as long as that machine
    has sufficient resources to handle both tasks. Note that the `ps` task does
    not benefit from a GPU, so it should not attempt to use one (see below).
*   Multiple `worker` tasks can run on the same machine with multiple GPUs so
    machine_A with 2 GPUs may have 2 workers while machine_B with 1 GPU just has
    1 worker.
*   The default learning rate schedule works well for a wide range of replica
    counts (25, 50, 100), but feel free to tune it for even better results.
*   The command line of both `ps` and `worker` tasks should include the complete
    list of `ps_hosts` and `worker_hosts`.
*   There is a chief `worker` among all workers which defaults to `worker` 0.
    The chief will be in charge of initializing all the parameters, writing out
    the summaries and the checkpoint. The checkpoint and summary will be in the
    `train_dir` of the host for `worker` 0.
*   Each worker processes `batch_size` examples per step, but each gradient
    update is computed from all replicas. Hence, the effective batch size of
    this model is `batch_size * num_workers`.
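
The bookkeeping in these notes can be sketched as follows:

*   The task counts come from the comma-separated host lists.
*   `task_id` must fall in the half-open range for its job.
*   Worker 0 is the chief.
*   The effective batch size is `batch_size * num_workers`.

The helper below is hypothetical and written only for illustration; it is not code taken from `imagenet_distributed_train`.

```python
from typing import Dict, List


def parse_hosts(flag_value: str) -> List[str]:
    """Split a flag such as 'worker0.example.com:2222,worker1.example.com:2222'."""
    return [host.strip() for host in flag_value.split(",") if host.strip()]


def describe_task(job_name: str, task_id: int, ps_hosts: str, worker_hosts: str,
                  batch_size: int) -> Dict[str, object]:
    """Validate one task's flags and summarize its role in the cluster."""
    cluster = {"ps": parse_hosts(ps_hosts), "worker": parse_hosts(worker_hosts)}
    if job_name not in cluster:
        raise ValueError("job_name must be 'ps' or 'worker'")
    num_tasks = len(cluster[job_name])
    if not 0 <= task_id < num_tasks:
        raise ValueError("task_id must be in [0, %d) for job %r" % (num_tasks, job_name))
    num_workers = len(cluster["worker"])
    return {
        "num_ps_tasks": len(cluster["ps"]),
        "num_worker_tasks": num_workers,
        "is_chief": job_name == "worker" and task_id == 0,
        "effective_batch_size": batch_size * num_workers,
    }


if __name__ == "__main__":
    info = describe_task(
        job_name="worker",
        task_id=1,
        ps_hosts="ps0.example.com:2222",
        worker_hosts="worker0.example.com:2222,worker1.example.com:2222",
        batch_size=32,
    )
    print(info)  # {'num_ps_tasks': 1, 'num_worker_tasks': 2, 'is_chief': False, 'effective_batch_size': 64}
```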

```shell
# Build the model. Note that we need to make sure TensorFlow is ready to
# use before this, as this command will not build TensorFlow.
bazel build inception/imagenet_distributed_train

# To start worker 0, go to the worker0 host and run the following (note that
# task_id should be in the range [0, num_worker_tasks)):
bazel-bin/inception/imagenet_distributed_train \
--batch_size=32 \
--data_dir=$HOME/imagenet-data \
--job_name='worker' \
--task_id=0 \
--ps_hosts='ps0.example.com:2222' \
--worker_hosts='worker0.example.com:2222,worker1.example.com:2222'

# To start worker 1, go to the worker1 host and run the following (note that
# task_id should be in the range [0, num_worker_tasks)):
bazel-bin/inception/imagenet_distributed_train \
--batch_size=32 \
--data_dir=$HOME/imagenet-data \
--job_name='worker' \
--task_id=1 \
--ps_hosts='ps0.example.com:2222' \
--worker_hosts='worker0.example.com:2222,worker1.example.com:2222'

# To start the parameter server (ps), go to the ps host and run the following
# (note that task_id should be in the range [0, num_ps_tasks)):
bazel-bin/inception/imagenet_distributed_train \
--job_name='ps' \
--task_id=0 \
--ps_hosts='ps0.example.com:2222' \
--worker_hosts='worker0.example.com:2222,worker1.example.com:2222'
```

If you have installed a GPU-compatible version of TensorFlow, the `ps` will also
try to allocate GPU memory, although this is not helpful. This could potentially
crash a worker on the same machine, since that worker would be left with little
to no GPU memory to allocate. To avoid this, you can prefix the command that
starts the `ps` with `CUDA_VISIBLE_DEVICES=''`:

```shell
CUDA_VISIBLE_DEVICES='' bazel-bin/inception/imagenet_distributed_train \
--job_name='ps' \
--task_id=0 \
--ps_hosts='ps0.example.com:2222' \
--worker_hosts='worker0.example.com:2222,worker1.example.com:2222'
```

If you have run everything correctly, you should see a log in each `worker` job
that looks like the following. Note the training speed varies depending on your
hardware and the first several steps could take much longer.

```shell
INFO:tensorflow:PS hosts are: ['ps0.example.com:2222', 'ps1.example.com:2222']
INFO:tensorflow:Worker hosts are: ['worker0.example.com:2222', 'worker1.example.com:2222']
I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:206] Initialize HostPortsGrpcChannelCache for job ps -> {ps0.example.com:2222, ps1.example.com:2222}
I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:206] Initialize HostPortsGrpcChannelCache for job worker -> {localhost:2222, worker1.example.com:2222}
I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:202] Started server with target: grpc://localhost:2222
INFO:tensorflow:Created variable global_step:0 with shape () and init <function zeros_initializer at 0x7f6aa014b140>

...

INFO:tensorflow:Created variable logits/logits/biases:0 with shape (1001,) and init <function _initializer at 0x7f6a77f3cf50>
INFO:tensorflow:SyncReplicas enabled: replicas_to_aggregate=2; total_num_replicas=2
INFO:tensorflow:2016-04-13 01:56:26.405639 Supervisor
INFO:tensorflow:Started 2 queues for processing input data.
INFO:tensorflow:global_step/sec: 0
INFO:tensorflow:Worker 0: 2016-04-13 01:58:40.342404: step 0, loss = 12.97(0.0 examples/sec; 65.428  sec/batch)
INFO:tensorflow:global_step/sec: 0.0172907
...
```

and a log in each `ps` job that looks like the following:

```shell
INFO:tensorflow:PS hosts are: ['ps0.example.com:2222', 'ps1.example.com:2222']
INFO:tensorflow:Worker hosts are: ['worker0.example.com:2222', 'worker1.example.com:2222']
I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:206] Initialize HostPortsGrpcChannelCache for job ps -> {localhost:2222, ps1.example.com:2222}
I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:206] Initialize HostPortsGrpcChannelCache for job worker -> {worker0.example.com:2222, worker1.example.com:2222}
I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:202] Started server with target: grpc://localhost:2222
```

[Congratulations!](https://www.youtube.com/watch?v=9bZkp7q19f0) You are now
training Inception in a distributed manner.

## How to Evaluate

Evaluating an Inception v3 model on the ImageNet 2012 validation data set
requires running a separate binary.

The evaluation procedure is nearly identical to
[Evaluating a Model](https://www.tensorflow.org/tutorials/deep_cnn/index.html#evaluating-a-model)
described in the
[Convolutional Neural Network](https://www.tensorflow.org/tutorials/deep_cnn/index.html)
tutorial.

**WARNING** Be careful not to run the evaluation and training binaries on the
same GPU, or else you might run out of memory. Consider running the evaluation
on a separate GPU if available, or suspending the training binary while running
the evaluation on the same GPU.

Briefly, one can evaluate the model by running:

```shell
# Build the model. Note that we need to make sure TensorFlow is ready to
# use before this, as this command will not build TensorFlow.
bazel build inception/imagenet_eval

# run it
bazel-bin/inception/imagenet_eval --checkpoint_dir=/tmp/imagenet_train --eval_dir=/tmp/imagenet_eval
```

Note that we point `--checkpoint_dir` to the location of the checkpoints saved
by `inception_train.py` above. Running the above command results in the
following output:

```shell
2016-02-17 22:32:50.391206: precision @ 1 = 0.735
...
```

The script periodically calculates the precision @ 1 over the entire validation
data set. The precision @ 1 measures how often the highest-scoring prediction
from the model matches the ImageNet label, in this case 73.5% of the time. If
you wish to run the evaluation just once rather than periodically, append the
`--run_once` option.
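
Precision @ 1, and the recall @ 5 mentioned below, both reduce to a top-k check: does the true label appear among the k highest-scoring classes? The sketch below computes that fraction from raw per-class scores in plain Python. It illustrates the metric only and is not the evaluation script's code. Ties between equal scores are broken here by class index, which is an arbitrary choice of this sketch.

```python
from typing import Sequence


def top_k_accuracy(scores: Sequence[Sequence[float]], labels: Sequence[int], k: int) -> float:
    """Fraction of examples whose true label is among the k highest scores."""
    if len(scores) != len(labels):
        raise ValueError("scores and labels must have the same length")
    if not scores:
        return 0.0
    hits = 0
    for row, label in zip(scores, labels):
        # Class indices of the k largest scores for this example.
        top_k = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        if label in top_k:
            hits += 1
    return hits / len(labels)


if __name__ == "__main__":
    # Three examples, four classes.
    scores = [
        [0.1, 0.7, 0.1, 0.1],  # true label 1 is the top prediction
        [0.5, 0.3, 0.2, 0.0],  # true label 1 is ranked second
        [0.2, 0.2, 0.1, 0.5],  # true label 2 is ranked last
    ]
    labels = [1, 1, 2]
    print("precision @ 1:", top_k_accuracy(scores, labels, k=1))  # 1 of 3 correct
    print("top-2 accuracy:", top_k_accuracy(scores, labels, k=2))  # 2 of 3 correct
```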

Much like the training script, `imagenet_eval.py` also exports summaries that
may be visualized in TensorBoard. These summaries calculate additional
statistics on the predictions (e.g. recall @ 5) as well as monitor the
statistics of the model activations and weights during evaluation.

## How to Fine-Tune a Pre-Trained Model on a New Task

### Getting Started

Much like training the ImageNet model, we must first convert a new data set to
the sharded TFRecord format, in which each entry is a serialized `tf.Example`
proto.

We have provided a script demonstrating how to do this for a small data set of
a few thousand flower images spread across 5 labels:

```shell
daisy, dandelion, roses, sunflowers, tulips
```

There is a single automated script that downloads the data set and converts it
to the TFRecord format. Much like the ImageNet data set, each record in the
TFRecord format is a serialized `tf.Example` proto whose entries include a
JPEG-encoded string and an integer label. Please see
[`parse_example_proto`](inception/image_processing.py) for details.

The script takes just a few minutes to run, depending on your network connection
speed for downloading and processing the images. Your hard disk requires 200 MB
of free storage. Here we select `FLOWERS_DATA_DIR=$HOME/flowers-data` as such a
location, but feel free to edit accordingly.

```shell
# location of where to place the flowers data
FLOWERS_DATA_DIR=$HOME/flowers-data

# build the preprocessing script.
bazel build inception/download_and_preprocess_flowers

# run it
bazel-bin/inception/download_and_preprocess_flowers "${FLOWERS_DATA_DIR}"
```

If the script runs successfully, the final line of the terminal output should
look like:

```shell
2016-02-24 20:42:25.067551: Finished writing all 3170 images in data set.
```

When the script finishes, you will find 2 shards each for the training and
validation files in the `FLOWERS_DATA_DIR`. The files will match the patterns
`train-?????-of-00002` and `validation-?????-of-00002`, respectively.

**NOTE** If you wish to prepare a custom image data set for transfer learning,
you will need to invoke [`build_image_data.py`](inception/data/build_image_data.py)
on your custom data set. Please see the associated options and assumptions behind
this script by reading the comments section of
[`build_image_data.py`](inception/data/build_image_data.py). Also, if your custom
data has a different number of examples or classes, you need to change the
appropriate values in [`imagenet_data.py`](inception/imagenet_data.py).

The second piece you will need is a trained Inception v3 image model. You have
the option of either training one yourself (see
[How to Train from Scratch](#how-to-train-from-scratch) for details) or
downloading a pre-trained model like so:

```shell
# location of where to place the Inception v3 model
INCEPTION_MODEL_DIR=$HOME/inception-v3-model
mkdir -p ${INCEPTION_MODEL_DIR}
cd ${INCEPTION_MODEL_DIR}

# download the Inception v3 model
curl -O http://download.tensorflow.org/models/image/imagenet/inception-v3-2016-03-01.tar.gz
tar xzf inception-v3-2016-03-01.tar.gz

# this will create a directory called inception-v3 which contains the following files.
ls inception-v3
# README.txt
# checkpoint
# model.ckpt-157585
```

[Congratulations!](https://www.youtube.com/watch?v=9bZkp7q19f0) You are now
ready to fine-tune your pre-trained Inception v3 model with the flower data set.

### How to Retrain a Trained Model on the Flowers Data

We are now ready to fine-tune a pre-trained Inception-v3 model on the flowers
data set. This requires two distinct changes to our training procedure:

1.  Build the exact same model as previously except we change the number of
    labels in the final classification layer.

2.  Restore all weights from the pre-trained Inception-v3 except for the final
    classification layer; this will get randomly initialized instead.

We can perform these two operations by specifying two flags:
`--pretrained_model_checkpoint_path` and `--fine_tune`. The first flag is a
string that points to the path of a pre-trained Inception-v3 model. If this flag
is specified, it will load the entire model from the checkpoint before the
script begins training.

The second flag, `--fine_tune`, is a boolean that indicates whether the last
classification layer should be randomly initialized or restored. You may set
this flag to false if you wish to continue training a pre-trained model from a
checkpoint. If you set this flag to true, you can train a new classification
layer from scratch.
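
Conceptually, `--fine_tune=True` restores every variable from the checkpoint except those in the final classification layer, which keep their fresh random initialization. The sketch below expresses that selection as a filter over variable names. The `logits/` scope prefix is an assumption based on the `logits/logits/biases` variable that appears in the distributed training log above. The helper is hypothetical and is not the TensorFlow-Slim mechanism the script actually uses.

```python
from typing import Iterable, List, Tuple


def split_restore_set(variable_names: Iterable[str], fine_tune: bool,
                      excluded_scopes: Tuple[str, ...] = ("logits/",)
                      ) -> Tuple[List[str], List[str]]:
    """Return (restore, initialize) lists of variable names.

    With fine_tune=False, every variable is restored from the checkpoint.
    With fine_tune=True, variables under excluded_scopes are left out of the
    restore list so they keep their random initialization.
    """
    restore: List[str] = []
    initialize: List[str] = []
    for name in variable_names:
        if fine_tune and name.startswith(excluded_scopes):
            initialize.append(name)
        else:
            restore.append(name)
    return restore, initialize


if __name__ == "__main__":
    # Illustrative names; only logits/logits/biases:0 appears in the log above.
    names = [
        "conv0/weights:0",
        "conv1/weights:0",
        "logits/logits/weights:0",
        "logits/logits/biases:0",
    ]
    restore, initialize = split_restore_set(names, fine_tune=True)
    print("restore:", restore)
    print("initialize randomly:", initialize)
```

Excluding the final layer matters because its shape depends on the number of classes, which differs between ImageNet and the flowers data.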

In order to understand how `--fine_tune` works, please see the discussion on
`Variables` in the TensorFlow-Slim [`README.md`](inception/slim/README.md).

Putting this all together you can retrain a pre-trained Inception-v3 model on
the flowers data set with the following command.

```shell
# Build the model. Note that we need to make sure TensorFlow is ready to
# use before this, as this command will not build TensorFlow.
bazel build inception/flowers_train

# Path to the downloaded Inception-v3 model.
MODEL_PATH="${INCEPTION_MODEL_DIR}/inception-v3/model.ckpt-157585"

# Directory where the flowers data resides.
FLOWERS_DATA_DIR=/tmp/flowers-data/

# Directory where to save the checkpoint and events files.
TRAIN_DIR=/tmp/flowers_train/

# Run the fine-tuning on the flowers data set starting from the pre-trained
# Inception-v3 model.
bazel-bin/inception/flowers_train \
  --train_dir="${TRAIN_DIR}" \
  --data_dir="${FLOWERS_DATA_DIR}" \
  --pretrained_model_checkpoint_path="${MODEL_PATH}" \
  --fine_tune=True \
  --initial_learning_rate=0.001 \
  --input_queue_memory_factor=1
```

We have added a few extra options to the training procedure.

*   Fine-tuning a model on a separate data set requires significantly lowering
    the initial learning rate. We set the initial learning rate to 0.001.
*   The flowers data set is quite small, so we shrink the size of the shuffling
    queue of examples. See [Adjusting Memory Demands](#adjusting-memory-demands)
    for more details.

The training script only reports the loss. To evaluate the quality of the
fine-tuned model, you will need to run `flowers_eval`:

```shell
573
574
575
# Build the model. Note that we need to make sure the TensorFlow is ready to
# use before this as this command will not build TensorFlow.
bazel build inception/flowers_eval
Martin Wicke's avatar
Martin Wicke committed
576
577
578
579
580
581
582
583
584
585
586

# Directory where we saved the fine-tuned checkpoint and events files.
TRAIN_DIR=/tmp/flowers_train/

# Directory where the flowers data resides.
FLOWERS_DATA_DIR=/tmp/flowers-data/

# Directory where to save the evaluation events files.
EVAL_DIR=/tmp/flowers_eval/

# Evaluate the fine-tuned model on a hold-out of the flower data set.
Weilin Xu's avatar
Weilin Xu committed
587
bazel-bin/inception/flowers_eval \
  --eval_dir="${EVAL_DIR}" \
  --data_dir="${FLOWERS_DATA_DIR}" \
  --subset=validation \
  --num_examples=500 \
  --checkpoint_dir="${TRAIN_DIR}" \
  --input_queue_memory_factor=1 \
  --run_once
```

We find that the evaluation arrives at roughly 93.4% precision@1 after the model
has been running for 2000 steps.

```shell
Successfully loaded model from /tmp/flowers/model.ckpt-1999 at step=1999.
2016-03-01 16:52:51.761219: starting evaluation on (validation).
2016-03-01 16:53:05.450419: [20 batches out of 20] (36.5 examples/sec; 0.684sec/batch)
2016-03-01 16:53:05.450471: precision @ 1 = 0.9340 recall @ 5 = 0.9960 [500 examples]
```
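When comparing several checkpoints, it can help to pull the summary numbers out of the evaluation output programmatically. The following is a minimal sketch that assumes the summary line keeps exactly the format shown in the sample output above (`precision @ 1 = ... recall @ 5 = ... [N examples]`); `parse_eval_summary` is a hypothetical helper, not part of this repository.

```python
import re

# Matches the summary line printed at the end of an evaluation run, e.g.
# "precision @ 1 = 0.9340 recall @ 5 = 0.9960 [500 examples]".
# The format is copied from the sample output and may change between versions.
_SUMMARY_RE = re.compile(
    r"precision @ 1 = (?P<p1>[0-9.]+) recall @ 5 = (?P<r5>[0-9.]+) "
    r"\[(?P<n>\d+) examples\]"
)


def parse_eval_summary(line):
    """Return (precision_at_1, recall_at_5, num_examples), or None if absent."""
    match = _SUMMARY_RE.search(line)
    if match is None:
        return None
    return (float(match.group("p1")),
            float(match.group("r5")),
            int(match.group("n")))


if __name__ == "__main__":
    sample = ("2016-03-01 16:53:05.450471: precision @ 1 = 0.9340 "
              "recall @ 5 = 0.9960 [500 examples]")
    p1, r5, n = parse_eval_summary(sample)
    # Approximate count of correctly classified examples at top-1.
    print(p1, r5, n, round(p1 * n))
```

With 500 validation examples, a precision@1 of 0.9340 corresponds to roughly 467 correctly classified images.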

## How to Construct a New Dataset for Retraining

One can use the existing scripts supplied with this model to build a new dataset
for training or fine-tuning. The main script to employ is
[`build_image_data.py`](inception/data/build_image_data.py). Briefly, this script takes a
structured directory of images and converts it to a sharded `TFRecord` that can
be read by the Inception model.

In particular, you will need to arrange your training and validation images in
`$TRAIN_DIR` and `$VALIDATION_DIR`, respectively, as follows:

```shell
  $TRAIN_DIR/dog/image0.jpeg
  $TRAIN_DIR/dog/image1.jpg
  $TRAIN_DIR/dog/image2.png
  ...
  $TRAIN_DIR/cat/weird-image.jpeg
  $TRAIN_DIR/cat/my-image.jpeg
  $TRAIN_DIR/cat/my-image.JPG
  ...
  $VALIDATION_DIR/dog/imageA.jpeg
  $VALIDATION_DIR/dog/imageB.jpg
  $VALIDATION_DIR/dog/imageC.png
  ...
  $VALIDATION_DIR/cat/weird-image.PNG
  $VALIDATION_DIR/cat/that-image.jpg
  $VALIDATION_DIR/cat/cat.JPG
  ...
```
**NOTE**: This script will append an extra background class indexed at 0, so
your class labels will range from 0 to num_labels. Using the example above, the
corresponding class labels generated from `build_image_data.py` will be as
follows:
```shell
0
1 dog
2 cat
```

Each sub-directory in `$TRAIN_DIR` and `$VALIDATION_DIR` corresponds to a unique
label for the images that reside within that sub-directory. The images may be
JPEG or PNG images. Other image types are not currently supported.
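Before running the conversion, a quick pre-flight check can confirm that every file sits in a label sub-directory and has a JPEG or PNG extension. The sketch below is hypothetical (it is not part of `build_image_data.py`). It assumes extensions are matched case-insensitively, as the mixed `.jpeg`, `.JPG` and `.PNG` names in the layout above suggest, and it only reports files; class indices come from the labels file, not from this listing.

```python
import os

# Extensions treated as JPEG or PNG. Matching is case-insensitive (an
# assumption of this sketch) because the example layout mixes ".jpeg",
# ".JPG" and ".PNG".
_SUPPORTED_EXTENSIONS = (".jpg", ".jpeg", ".png")


def list_labeled_images(root_dir):
    """Treat each sub-directory of root_dir as a label and collect its images.

    Returns (examples, skipped): a list of (image_path, label_name) pairs and
    a list of paths whose extension is not JPEG or PNG. Both are sorted by
    label and then by file name.
    """
    examples, skipped = [], []
    for label in sorted(os.listdir(root_dir)):
        label_dir = os.path.join(root_dir, label)
        if not os.path.isdir(label_dir):
            continue  # Only sub-directories define labels.
        for name in sorted(os.listdir(label_dir)):
            path = os.path.join(label_dir, name)
            if name.lower().endswith(_SUPPORTED_EXTENSIONS):
                examples.append((path, label))
            else:
                skipped.append(path)
    return examples, skipped
```

Running it on both `$TRAIN_DIR` and `$VALIDATION_DIR` and checking that `skipped` is empty catches stray files before they reach the conversion step.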

Once the data is arranged in this directory structure, we can run
`build_image_data.py` on the data to generate the sharded `TFRecord` dataset.
Each entry of the `TFRecord` is a serialized `tf.Example` protocol buffer. A
complete list of the information contained in the `tf.Example` is described in
the comments of `build_image_data.py`.

To run `build_image_data.py`, use the following commands:

```shell
# Location where the TFRecord data will be saved.
OUTPUT_DIRECTORY=$HOME/my-custom-data/

# Build the preprocessing script.
bazel build inception/build_image_data

# Convert the data. TRAIN_DIR, VALIDATION_DIR and LABELS_FILE must already be set.
bazel-bin/inception/build_image_data \
  --train_directory="${TRAIN_DIR}" \
  --validation_directory="${VALIDATION_DIR}" \
  --output_directory="${OUTPUT_DIRECTORY}" \
  --labels_file="${LABELS_FILE}" \
  --train_shards=128 \
  --validation_shards=24 \
  --num_threads=8
```

where `$OUTPUT_DIRECTORY` is the location of the sharded `TFRecords`. The
`$LABELS_FILE` is a text file, read by the script, that lists all of the labels.
For instance, in the case of the flowers data set, the `$LABELS_FILE` contained
the following data:

```shell
daisy
dandelion
roses
sunflowers
tulips
```

Note that each row of the labels file corresponds to an entry in the final
classifier of the model. That is, `daisy` corresponds to classifier entry `1`,
`dandelion` to entry `2`, and so on. We skip label `0`, which is reserved as a
background class.
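The index assignment described above can be sketched in a few lines. This is a minimal illustration that assumes one label per line and ignores blank lines; `read_label_map` is a hypothetical name, and `build_image_data.py` performs its own equivalent bookkeeping.

```python
def read_label_map(labels_text):
    """Assign class indices to labels listed one per line.

    Index 0 is left unused for the background class, so the first listed
    label gets index 1, the second gets index 2, and so on. Blank lines are
    ignored (an assumption of this sketch).
    """
    labels = [line.strip() for line in labels_text.splitlines() if line.strip()]
    return {label: index for index, label in enumerate(labels, start=1)}


if __name__ == "__main__":
    flowers = "daisy\ndandelion\nroses\nsunflowers\ntulips\n"
    print(read_label_map(flowers))
```

For the two-class example earlier, a labels file listing `dog` and then `cat` maps to `{'dog': 1, 'cat': 2}`, matching the indices shown for that layout.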

Running this script produces files that look like the following:

```shell
  $OUTPUT_DIRECTORY/train-00000-of-00128
  $OUTPUT_DIRECTORY/train-00001-of-00128
  ...
  $OUTPUT_DIRECTORY/train-00127-of-00128

and

  $OUTPUT_DIRECTORY/validation-00000-of-00024
  $OUTPUT_DIRECTORY/validation-00001-of-00024
  ...
  $OUTPUT_DIRECTORY/validation-00023-of-00024
```
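The shard names follow a fixed pattern: the split name, a zero-padded five-digit shard index, and the zero-padded shard total. A hypothetical helper such as the one below, assuming that pattern as read off the listing above, can enumerate the expected files and report any that are missing, for example after an interrupted conversion.

```python
import os


def expected_shard_paths(output_dir, name, num_shards):
    """List the shard files expected for one split, e.g. train-00000-of-00128."""
    return [
        os.path.join(output_dir, f"{name}-{shard:05d}-of-{num_shards:05d}")
        for shard in range(num_shards)
    ]


def missing_shards(output_dir, name, num_shards):
    """Return the expected shard paths that do not exist on disk."""
    return [path for path in expected_shard_paths(output_dir, name, num_shards)
            if not os.path.exists(path)]
```

For the command above, `missing_shards(output_dir, "train", 128)` and `missing_shards(output_dir, "validation", 24)` should both return empty lists once the conversion has finished.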

where 128 and 24 are the number of shards specified for each dataset,
respectively. Generally speaking, we aim to select the number of shards such
that roughly 1024 images reside in each shard. Once this data set is built, you
are ready to train or fine-tune an Inception model on it.
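The rule of thumb of roughly 1024 images per shard translates into a small calculation. The helper below is a hypothetical sketch, not part of the repository. Besides targeting about 1024 images per shard, it rounds the result up to a multiple of `--num_threads`, on the assumption that shards should divide evenly among the conversion threads (both 128 and 24 are multiples of the `--num_threads=8` used above).

```python
import math


def choose_num_shards(num_images, num_threads=8, images_per_shard=1024):
    """Pick a shard count giving roughly `images_per_shard` images per shard.

    The count is rounded up to a multiple of `num_threads` (an assumption of
    this sketch), so at least one shard per thread is always returned.
    """
    target = max(1, math.ceil(num_images / images_per_shard))
    return math.ceil(target / num_threads) * num_threads
```

For example, a hypothetical set of 100,000 training images yields 104 shards with 8 threads, slightly under 1024 images per shard because of the rounding.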

Note: if you are piggybacking on the flowers retraining scripts, be sure to
update `num_classes()` and `num_examples_per_epoch()` in `flowers_data.py`
to correspond with your data.

## Practical Considerations for Training a Model

The model architecture and training procedure are heavily dependent on the
hardware used to train the model. If you wish to train or fine-tune this model
on your machine **you will need to adjust and empirically determine a good set
of training hyper-parameters for your setup**. What follows are some general
considerations for novices.

### Finding Good Hyperparameters

Roughly 5-10 hyper-parameters govern the speed at which a network is trained. In
addition to `--batch_size` and `--num_gpus`, several constants defined in
[inception_train.py](inception/inception_train.py) dictate the learning
schedule.

```python
RMSPROP_DECAY = 0.9                # Decay term for RMSProp.
MOMENTUM = 0.9                     # Momentum in RMSProp.
RMSPROP_EPSILON = 1.0              # Epsilon term for RMSProp.
INITIAL_LEARNING_RATE = 0.1        # Initial learning rate.
NUM_EPOCHS_PER_DECAY = 30.0        # Epochs after which learning rate decays.
LEARNING_RATE_DECAY_FACTOR = 0.16  # Learning rate decay factor.
```
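To get a feel for how these constants interact, the sketch below computes the learning rate at a given training step. It assumes a staircase exponential decay (the rate is multiplied by `LEARNING_RATE_DECAY_FACTOR` once every `NUM_EPOCHS_PER_DECAY` epochs), which is how we read the constant names; check `inception_train.py` for the exact schedule it uses.

```python
def learning_rate_at(step, num_examples_per_epoch, batch_size,
                     initial_learning_rate=0.1,
                     num_epochs_per_decay=30.0,
                     learning_rate_decay_factor=0.16):
    """Staircase exponential decay of the learning rate (assumed schedule)."""
    num_batches_per_epoch = num_examples_per_epoch / batch_size
    # Guard against a zero step count for tiny data sets (a choice of this sketch).
    decay_steps = max(1, int(num_batches_per_epoch * num_epochs_per_decay))
    # The exponent only increases once every `decay_steps` steps.
    return initial_learning_rate * learning_rate_decay_factor ** (step // decay_steps)


if __name__ == "__main__":
    # Hypothetical setup: 1,000 examples per epoch and a batch size of 10.
    for step in (0, 2999, 3000, 6000):
        print(step, learning_rate_at(step, num_examples_per_epoch=1000, batch_size=10))
```

With these hypothetical numbers there are 100 steps per epoch, so the rate stays at 0.1 for the first 3,000 steps, drops to 0.016, and falls to 0.00256 after 6,000 steps. Because the decay interval is measured in epochs, changing `--batch_size` changes how many steps pass between decays.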

There are many papers that discuss the various tricks and trade-offs associated
with training a model with stochastic gradient descent. For those new to the
field, some great references are:

*   Y Bengio, [Practical recommendations for gradient-based training of deep
    architectures](http://arxiv.org/abs/1206.5533)
*   I Goodfellow, Y Bengio and A Courville, [Deep Learning](http://www.deeplearningbook.org/)

What follows is a summary of some general advice for identifying appropriate
model hyper-parameters in the context of this particular model training setup.
Namely, this library provides *synchronous* updates to model parameters by
splitting each batch across multiple GPUs.

*   Higher learning rates lead to faster training. Too high a learning rate
    leads to instability and will cause model parameters to diverge to infinity
    or NaN.

*   Larger batch sizes lead to higher quality estimates of the gradient and
    permit training the model with higher learning rates.

*   Often the GPU memory is a bottleneck that prevents employing larger batch
    sizes. Employing more GPUs allows one to use larger batch sizes because
    this model splits the batch across the GPUs.
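The synchronous batch-splitting described above can be illustrated with a toy, framework-free sketch. It assumes a one-parameter linear model with a mean squared error loss, a batch size divisible by the number of GPUs, and that the per-GPU ("tower") gradients are averaged before a single parameter update. It shows the idea only and is not the TensorFlow implementation in `inception_train.py`.

```python
def tower_gradient(w, xs, ys):
    """Gradient w.r.t. w of the mean squared error of y approx w * x on one slice."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def split_batch(batch, num_gpus):
    """Split a batch into `num_gpus` equal, contiguous slices."""
    if len(batch) % num_gpus != 0:
        raise ValueError("batch size must be divisible by the number of GPUs")
    size = len(batch) // num_gpus
    return [batch[i * size:(i + 1) * size] for i in range(num_gpus)]


def synchronous_step(w, xs, ys, num_gpus, learning_rate):
    """One synchronous update: each tower computes a gradient on its slice,
    the tower gradients are averaged, and a single update is applied to w."""
    x_splits = split_batch(xs, num_gpus)
    y_splits = split_batch(ys, num_gpus)
    grads = [tower_gradient(w, x, y) for x, y in zip(x_splits, y_splits)]
    avg_grad = sum(grads) / num_gpus
    return w - learning_rate * avg_grad
```

Because every slice has the same size, the averaged tower gradients equal the full-batch gradient, so one synchronous step on N GPUs matches one step on a single device holding the whole batch. This is why N GPUs can process N times the per-GPU batch in each step.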

**NOTE**: If one wishes to train this model with *asynchronous* gradient
updates, one will need to substantially alter this model, and new
considerations need to be factored into hyperparameter tuning. See [Large Scale
Distributed Deep Networks](http://research.google.com/archive/large_deep_networks_nips2012.html)
for a discussion in this domain.

### Adjusting Memory Demands

Training this model has large memory demands in terms of the CPU and GPU. Let's
discuss each item in turn.

GPU memory is relatively small compared to CPU memory. Two items dictate the
amount of GPU memory employed: model architecture and batch size. Assuming
that you keep the model architecture fixed, the sole parameter governing the GPU
demand is the batch size. A good rule of thumb is to try to employ as large a
batch size as will fit on the GPU.

If you run out of GPU memory, either lower the `--batch_size` or employ more
GPUs on your desktop. The model performs batch-splitting across GPUs, so N
GPUs can handle N times the batch size of 1 GPU.

The model requires a large amount of CPU memory as well. We have tuned the model
to employ about 20GB of CPU memory. Thus, having access to about 40GB of CPU
memory would be ideal.

If that is not possible, you can tune down the memory demands of the model by
lowering `--input_queue_memory_factor`. Images are preprocessed asynchronously
with respect to the main training across `--num_preprocess_threads` threads. The
preprocessed images are stored in a shuffling queue from which each GPU dequeues
a `batch_size` worth of images.

To guarantee good shuffling across the data, we maintain a large shuffling
queue of 1024 x `input_queue_memory_factor` images. For the current model
architecture, this corresponds to about 4GB of CPU memory. You may lower
`input_queue_memory_factor` in order to decrease the memory footprint. Keep in
mind, though, that lowering this value drastically may result in a model with
slightly lower predictive accuracy when training from scratch. Please see the
comments in [`image_processing.py`](inception/image_processing.py) for more details.
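The relationship between `input_queue_memory_factor` and CPU memory is simple arithmetic. The sketch below assumes each queued image is a preprocessed 299x299x3 tensor of 4-byte floats and ignores all other overhead, so treat its result as a rough estimate of the queue alone rather than a measurement.

```python
def shuffle_queue_bytes(input_queue_memory_factor,
                        examples_per_shard=1024,
                        image_size=299,
                        channels=3,
                        bytes_per_value=4):
    """Rough size in bytes of the image shuffling queue.

    Assumes every queued image is a preprocessed image_size x image_size x
    channels tensor of 4-byte floats and ignores all other overhead.
    """
    bytes_per_image = image_size * image_size * channels * bytes_per_value
    queue_capacity = examples_per_shard * input_queue_memory_factor
    return queue_capacity * bytes_per_image


if __name__ == "__main__":
    for factor in (1, 2, 4, 8):
        print(factor, round(shuffle_queue_bytes(factor) / 1e9, 1), "GB")
```

Under these assumptions each image takes about 1.07 MB, so every unit of `input_queue_memory_factor` costs roughly 1.1 GB of CPU memory, and a factor of 4 works out to about 4.4 GB.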

## Troubleshooting

#### The model runs out of CPU memory.

In lieu of buying more CPU memory, an easy fix is to decrease
`--input_queue_memory_factor`. See
[Adjusting Memory Demands](#adjusting-memory-demands).

#### The model runs out of GPU memory.

The data does not fit on the GPU card. The simplest solution is to
decrease the batch size of the model. Otherwise, you will need to think about a
more sophisticated way of specifying the training, one that splits the model
across multiple `session.run()` calls or partitions the model across multiple
GPUs. See [Using GPUs](https://www.tensorflow.org/how_tos/using_gpu/index.html)
and [Adjusting Memory Demands](#adjusting-memory-demands) for more information.

#### The model training results in NaNs.

The learning rate of the model is most likely too high. Lower your learning
rate, for example via `--initial_learning_rate`.
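If you write your own training loop, checking the loss on every step catches divergence as soon as it happens instead of after hours of wasted training. The helper below is a hypothetical sketch in plain Python; it treats infinity the same as NaN, since both indicate that the parameters have diverged.

```python
import math


def check_loss(step, loss_value):
    """Return loss_value, or raise if the loss has diverged to NaN or infinity."""
    if not math.isfinite(loss_value):
        raise FloatingPointError(
            f"Model diverged at step {step} with loss = {loss_value}; "
            "try lowering the initial learning rate.")
    return loss_value
```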

#### I wish to train a model with a different image size.

The simplest solution is to artificially resize your images to `299x299` pixels.
See the [Images](https://www.tensorflow.org/api_docs/python/image.html) section
for many resizing, cropping and padding methods. Note that the entire model
architecture is predicated on a `299x299` image; thus, if you wish to change the
input image size, you may need to redesign the entire model architecture.
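One way to reach `299x299` without distorting the aspect ratio is to take the largest centered square from each image and then resize that square. The plain-Python sketch below computes only the crop geometry. It assumes you then crop and resize with an image routine of your choice (for example, `tf.image.crop_to_bounding_box` with the returned offsets and `size` as both target dimensions, followed by a resize op), and it is an illustration rather than the preprocessing this model uses.

```python
def center_square_crop_box(height, width):
    """Return (offset_height, offset_width, size) of the largest centered square.

    Cropping an image to this box and then resizing the size x size result to
    299x299 preserves the aspect ratio of the retained region.
    """
    size = min(height, width)
    offset_height = (height - size) // 2
    offset_width = (width - size) // 2
    return offset_height, offset_width, size


if __name__ == "__main__":
    # A hypothetical 640-wide, 480-tall photo keeps its central 480x480 square.
    print(center_square_crop_box(height=480, width=640))
```

The trade-off is that content near the long edges is discarded; padding to a square instead keeps all pixels at the cost of blank borders.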

#### What hardware specification are these hyper-parameters targeted for?

We targeted a desktop with 128GB of CPU RAM connected to 8 NVIDIA Tesla K40 GPU
cards, but we have run this on desktops with 32GB of CPU RAM and 1 NVIDIA Tesla
K40. You can get a sense of the various training configurations we tested by
reading the comments in [`inception_train.py`](inception/inception_train.py).

#### How do I continue training from a checkpoint in a distributed setting?

You only need to make sure that the checkpoint is in a location that can be
reached by all of the `ps` tasks. By specifying the checkpoint location with
`--train_dir`, the `ps` servers will load the checkpoint before commencing
training.