Commit 1abcb9c3 authored by Suharsh Sivakumar's avatar Suharsh Sivakumar

Add mobilenet train and eval scripts and update documentation to include quantization support.

parent 9a9e4228
@@ -429,6 +429,28 @@ py_library(
    ],
)

py_binary(
    name = "mobilenet_v1_train",
    srcs = ["nets/mobilenet_v1_train.py"],
    deps = [
        ":dataset_factory",
        ":mobilenet_v1",
        ":preprocessing_factory",
        # "//tensorflow"
    ],
)

py_binary(
    name = "mobilenet_v1_eval",
    srcs = ["nets/mobilenet_v1_eval.py"],
    deps = [
        ":dataset_factory",
        ":mobilenet_v1",
        ":preprocessing_factory",
        # "//tensorflow"
    ],
)

py_test(
    name = "mobilenet_v1_test",
    size = "large",
@@ -256,9 +256,9 @@ Model | TF-Slim File | Checkpoint | Top-1 Accuracy| Top-5 Accuracy |
[ResNet V2 200](https://arxiv.org/abs/1603.05027)|[Code](https://github.com/tensorflow/models/blob/master/research/slim/nets/resnet_v2.py)|[TBA]()|79.9\*|95.2\*|
[VGG 16](http://arxiv.org/abs/1409.1556.pdf)|[Code](https://github.com/tensorflow/models/blob/master/research/slim/nets/vgg.py)|[vgg_16_2016_08_28.tar.gz](http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz)|71.5|89.8|
[VGG 19](http://arxiv.org/abs/1409.1556.pdf)|[Code](https://github.com/tensorflow/models/blob/master/research/slim/nets/vgg.py)|[vgg_19_2016_08_28.tar.gz](http://download.tensorflow.org/models/vgg_19_2016_08_28.tar.gz)|71.1|89.8|
[MobileNet_v1_1.0_224](https://arxiv.org/pdf/1704.04861.pdf)|[Code](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.py)|[mobilenet_v1_1.0_224.tgz](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz)|70.9|89.9|
[MobileNet_v1_0.50_160](https://arxiv.org/pdf/1704.04861.pdf)|[Code](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.py)|[mobilenet_v1_0.50_160.tgz](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_160.tgz)|59.1|81.9|
[MobileNet_v1_0.25_128](https://arxiv.org/pdf/1704.04861.pdf)|[Code](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.py)|[mobilenet_v1_0.25_128.tgz](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_128.tgz)|41.5|66.3|
[NASNet-A_Mobile_224](https://arxiv.org/abs/1707.07012)#|[Code](https://github.com/tensorflow/models/blob/master/research/slim/nets/nasnet/nasnet.py)|[nasnet-a_mobile_04_10_2017.tar.gz](https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_mobile_04_10_2017.tar.gz)|74.0|91.6|
[NASNet-A_Large_331](https://arxiv.org/abs/1707.07012)#|[Code](https://github.com/tensorflow/models/blob/master/research/slim/nets/nasnet/nasnet.py)|[nasnet-a_large_04_10_2017.tar.gz](https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_large_04_10_2017.tar.gz)|82.7|96.2|
@@ -269,7 +269,9 @@ reported on the ImageNet validation set.
(#) More information and details about the NASNet architectures are available at this [README](nets/nasnet/README.md)
All 16 float MobileNet V1 models reported in the [MobileNet Paper](https://arxiv.org/abs/1704.04861) and all
16 quantized [TensorFlow Lite](https://www.tensorflow.org/mobile/tflite/)-compatible MobileNet V1 models can be found
[here](https://github.com/tensorflow/models/tree/master/research/slim/nets/mobilenet_v1.md).
(\*): Results quoted from the [paper](https://arxiv.org/abs/1603.05027).
@@ -515,3 +517,4 @@ image_preprocessing_fn = preprocessing_factory.get_preprocessing(
See
[Hardware Specifications](https://github.com/tensorflow/models/tree/master/research/inception#what-hardware-specification-are-these-hyper-parameters-targeted-for).
@@ -12,36 +12,118 @@ Choose the right MobileNet model to fit your latency and size budget. The size o
[ILSVRC-2012-CLS](http://www.image-net.org/challenges/LSVRC/2012/)
image classification dataset. Accuracies were computed by evaluating using a single image crop.
Model | Million MACs | Million Parameters | Top-1 Accuracy| Top-5 Accuracy |
:----:|:------------:|:----------:|:-------:|:-------:|
[MobileNet_v1_1.0_224](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz)|569|4.24|70.9|89.9|
[MobileNet_v1_1.0_192](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_192.tgz)|418|4.24|70.0|89.2|
[MobileNet_v1_1.0_160](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_160.tgz)|291|4.24|68.0|87.7|
[MobileNet_v1_1.0_128](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_128.tgz)|186|4.24|65.2|85.8|
[MobileNet_v1_0.75_224](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_224.tgz)|317|2.59|68.4|88.2|
[MobileNet_v1_0.75_192](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_192.tgz)|233|2.59|67.2|87.3|
[MobileNet_v1_0.75_160](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_160.tgz)|162|2.59|65.3|86.0|
[MobileNet_v1_0.75_128](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_128.tgz)|104|2.59|62.1|83.9|
[MobileNet_v1_0.50_224](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_224.tgz)|150|1.34|63.3|84.9|
[MobileNet_v1_0.50_192](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_192.tgz)|110|1.34|61.7|83.6|
[MobileNet_v1_0.50_160](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_160.tgz)|77|1.34|59.1|81.9|
[MobileNet_v1_0.50_128](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_128.tgz)|49|1.34|56.3|79.4|
[MobileNet_v1_0.25_224](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_224.tgz)|41|0.47|49.8|74.2|
[MobileNet_v1_0.25_192](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_192.tgz)|34|0.47|47.7|72.3|
[MobileNet_v1_0.25_160](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_160.tgz)|21|0.47|45.5|70.3|
[MobileNet_v1_0.25_128](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_128.tgz)|14|0.47|41.5|66.3|
[MobileNet_v1_1.0_224_quant](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224_quant.tgz)|569|4.24|69.7|89.5|
[MobileNet_v1_1.0_192_quant](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_192_quant.tgz)|418|4.24|69.0|88.9|
[MobileNet_v1_1.0_160_quant](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_160_quant.tgz)|291|4.24|67.3|87.7|
[MobileNet_v1_1.0_128_quant](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_128_quant.tgz)|186|4.24|64.0|85.5|
[MobileNet_v1_0.75_224_quant](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_224_quant.tgz)|317|2.59|67.9|88.1|
[MobileNet_v1_0.75_192_quant](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_192_quant.tgz)|233|2.59|66.2|87.1|
[MobileNet_v1_0.75_160_quant](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_160_quant.tgz)|162|2.59|63.9|85.5|
[MobileNet_v1_0.75_128_quant](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_128_quant.tgz)|104|2.59|59.8|82.8|
[MobileNet_v1_0.50_224_quant](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_224_quant.tgz)|150|1.34|62.2|84.5|
[MobileNet_v1_0.50_192_quant](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_192_quant.tgz)|110|1.34|60.4|83.2|
[MobileNet_v1_0.50_160_quant](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_160_quant.tgz)|77|1.34|57.7|81.3|
[MobileNet_v1_0.50_128_quant](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_128_quant.tgz)|49|1.34|54.9|78.9|
[MobileNet_v1_0.25_224_quant](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_224_quant.tgz)|41|0.47|48.2|73.8|
[MobileNet_v1_0.25_192_quant](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_192_quant.tgz)|34|0.47|45.8|71.9|
[MobileNet_v1_0.25_160_quant](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_160_quant.tgz)|21|0.47|43.5|69.1|
[MobileNet_v1_0.25_128_quant](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_128_quant.tgz)|14|0.47|39.9|65.8|
The linked model tar files contain the following:
- Trained model checkpoints
- Eval graph text protos (to be easily viewed)
- Frozen trained models
- Info file containing input and output information
- Converted [TensorFlow Lite](https://www.tensorflow.org/mobile/tflite/) flatbuffer model
Note that the quantized model GraphDefs are still float models; they simply have FakeQuantization
ops embedded to simulate quantization. [TensorFlow Lite](https://www.tensorflow.org/mobile/tflite/)
converts these into fully quantized models. The final effect of quantization can be seen by
comparing the size of the frozen fake-quantized graph to that of the TFLite flatbuffer: the TFLite
flatbuffer is about 1/4 the size.
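
As a rough sanity check, this size difference can be read straight off an extracted archive. A minimal sketch in Python, assuming the MobileNet_v1_1.0_224 tarball has been unpacked into the working directory (the exact file names below are assumptions; the info file in each tarball lists the actual ones):

```python
import os

# Assumed file names from the extracted mobilenet_v1_1.0_224 archive; check
# the bundled info file if your tarball uses different names.
frozen_graph = 'mobilenet_v1_1.0_224_frozen.pb'
tflite_flatbuffer = 'mobilenet_v1_1.0_224.tflite'

pb_bytes = os.path.getsize(frozen_graph)
tflite_bytes = os.path.getsize(tflite_flatbuffer)
print('Frozen graph:      %.1f MB' % (pb_bytes / 1e6))
print('TFLite flatbuffer: %.1f MB' % (tflite_bytes / 1e6))
# For the quantized models the ratio should come out near 0.25.
print('Size ratio:        %.2f' % (float(tflite_bytes) / pb_bytes))
```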
For more information on the quantization techniques used here, see
[contrib/quantize](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/quantize).
Here is an example of how to download the MobileNet_v1_1.0_224 checkpoint:
```shell
$ CHECKPOINT_DIR=/tmp/checkpoints
$ mkdir ${CHECKPOINT_DIR}
$ wget http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz
$ tar -xvf mobilenet_v1_1.0_224.tgz
$ mv mobilenet_v1_1.0_224.ckpt.* ${CHECKPOINT_DIR}
$ rm mobilenet_v1_1.0_224.tgz
```
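
Once the checkpoint files are in place, they can be restored into the slim model definition. A minimal inference-setup sketch, assuming the TF-Slim `nets` package is on the Python path and using the checkpoint prefix implied by the file names above:

```python
import tensorflow as tf
from nets import mobilenet_v1

slim = tf.contrib.slim

# Placeholder for a batch of 224x224 RGB images.
images = tf.placeholder(tf.float32, [None, 224, 224, 3])
with slim.arg_scope(mobilenet_v1.mobilenet_v1_arg_scope(is_training=False)):
  logits, _ = mobilenet_v1.mobilenet_v1(
      images, num_classes=1001, is_training=False)

saver = tf.train.Saver()
with tf.Session() as sess:
  # The .ckpt prefix is an assumption based on the extracted file names.
  saver.restore(sess, '/tmp/checkpoints/mobilenet_v1_1.0_224.ckpt')
```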
More information on integrating MobileNets into your project can be found at the [TF-Slim Image Classification Library](https://github.com/tensorflow/models/blob/master/research/slim/README.md).
To get started running models on-device go to [TensorFlow Mobile](https://www.tensorflow.org/mobile/).
# MobileNet V1 scripts
This package contains scripts for training floating-point and eight-bit fixed-point
TensorFlow models.

The quantization tools used are described in [contrib/quantize](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/quantize).

Conversion to fully quantized models for mobile can be done through [TensorFlow Lite](https://www.tensorflow.org/mobile/tflite/).
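
For orientation, the rewriting step these tools perform is a single call against an already-built graph. A minimal sketch, assuming a model and its losses have already been constructed in the default graph:

```python
import tensorflow as tf

# Rewrites the default graph in place: inserts FakeQuantization ops and folds
# batch norms so training simulates eight-bit inference. quant_delay postpones
# quantization so the float weights can stabilize first; the train script
# below uses 250000 steps when training from scratch.
tf.contrib.quantize.create_training_graph(quant_delay=250000)

# The matching rewrite at eval/export time is:
# tf.contrib.quantize.create_eval_graph()
```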
## Usage
### Build for GPU
```
$ bazel build -c opt --config=cuda mobilenet_v1_{eval,train}
```
### Running
#### Float Training and Eval
Train:
```
$ ./bazel-bin/mobilenet_v1_train
```
Eval:
```
$ ./bazel-bin/mobilenet_v1_eval
```
#### Quantized Training and Eval
Train from a preexisting float checkpoint:
```
$ ./bazel-bin/mobilenet_v1_train --quantize=True --fine_tune_checkpoint=checkpoint-name
```
Train from scratch:
```
$ ./bazel-bin/mobilenet_v1_train --quantize=True
```
Eval:
```
$ ./bazel-bin/mobilenet_v1_eval --quantize=True
```
The resulting float and quantized models can be run on-device via [TensorFlow Lite](https://www.tensorflow.org/mobile/tflite/).
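
As a concrete endpoint, a frozen (float or fake-quantized) graph can be converted to a TFLite flatbuffer from Python. A hedged sketch, assuming a TensorFlow release that ships `tf.contrib.lite.TocoConverter` and the tensor names used by the released MobileNet graphs (verify both against your install and your graph):

```python
import tensorflow as tf

# Input/output tensor names are assumptions based on the released frozen
# MobileNet_v1 graphs; inspect the graph if conversion fails.
converter = tf.contrib.lite.TocoConverter.from_frozen_graph(
    'mobilenet_v1_1.0_224_frozen.pb',
    input_arrays=['input'],
    output_arrays=['MobilenetV1/Predictions/Reshape_1'])
tflite_model = converter.convert()
with open('mobilenet_v1_1.0_224.tflite', 'wb') as f:
  f.write(tflite_model)
```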
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Validate mobilenet_v1 with options for quantization."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import tensorflow as tf
from datasets import dataset_factory
from nets import mobilenet_v1
from preprocessing import preprocessing_factory
slim = tf.contrib.slim
flags = tf.app.flags
flags.DEFINE_string('master', '', 'Session master')
flags.DEFINE_integer('batch_size', 250, 'Batch size')
flags.DEFINE_integer('num_classes', 1001, 'Number of classes to distinguish')
flags.DEFINE_integer('num_examples', 50000, 'Number of examples to evaluate')
flags.DEFINE_integer('image_size', 224, 'Input image resolution')
flags.DEFINE_float('depth_multiplier', 1.0, 'Depth multiplier for mobilenet')
flags.DEFINE_bool('quantize', False, 'Quantize training')
flags.DEFINE_string('checkpoint_dir', '', 'The directory for checkpoints')
flags.DEFINE_string('eval_dir', '', 'Directory for writing eval event logs')
flags.DEFINE_string('dataset_dir', '', 'Location of dataset')
FLAGS = flags.FLAGS
def imagenet_input(is_training):
"""Data reader for imagenet.
Reads in imagenet data and performs pre-processing on the images.
Args:
is_training: bool specifying if train or validation dataset is needed.
Returns:
A batch of images and labels.
"""
if is_training:
dataset = dataset_factory.get_dataset('imagenet', 'train',
FLAGS.dataset_dir)
else:
dataset = dataset_factory.get_dataset('imagenet', 'validation',
FLAGS.dataset_dir)
provider = slim.dataset_data_provider.DatasetDataProvider(
dataset,
shuffle=is_training,
common_queue_capacity=2 * FLAGS.batch_size,
common_queue_min=FLAGS.batch_size)
[image, label] = provider.get(['image', 'label'])
image_preprocessing_fn = preprocessing_factory.get_preprocessing(
'mobilenet_v1', is_training=is_training)
image = image_preprocessing_fn(image, FLAGS.image_size, FLAGS.image_size)
images, labels = tf.train.batch(
tensors=[image, label],
batch_size=FLAGS.batch_size,
num_threads=4,
capacity=5 * FLAGS.batch_size)
return images, labels

def metrics(logits, labels):
  """Specify the metrics for eval.

  Args:
    logits: Logits output from the graph.
    labels: Ground truth labels for inputs.

  Returns:
    Eval Op for the graph.
  """
  labels = tf.squeeze(labels)
  names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
      'Accuracy': tf.metrics.accuracy(tf.argmax(logits, 1), labels),
      'Recall_5': tf.metrics.recall_at_k(labels, logits, 5),
  })
  # items() rather than iteritems() so the loop works on both PY2 and PY3.
  for name, value in names_to_values.items():
    slim.summaries.add_scalar_summary(
        value, name, prefix='eval', print_summary=True)
  return names_to_updates.values()

def build_model():
  """Build the mobilenet_v1 model for evaluation.

  Returns:
    g: graph with rewrites after insertion of quantization ops and batch norm
    folding.
    eval_ops: eval ops for inference.
  """
  g = tf.Graph()
  with g.as_default():
    inputs, labels = imagenet_input(is_training=False)

    scope = mobilenet_v1.mobilenet_v1_arg_scope(
        is_training=False, weight_decay=0.0)
    with slim.arg_scope(scope):
      logits, _ = mobilenet_v1.mobilenet_v1(
          inputs,
          is_training=False,
          depth_multiplier=FLAGS.depth_multiplier,
          num_classes=FLAGS.num_classes)

    if FLAGS.quantize:
      tf.contrib.quantize.create_eval_graph()

    eval_ops = metrics(logits, labels)

  return g, eval_ops

def eval_model():
  """Evaluates mobilenet_v1."""
  g, eval_ops = build_model()
  with g.as_default():
    num_batches = math.ceil(FLAGS.num_examples / float(FLAGS.batch_size))
    slim.evaluation.evaluate_once(
        FLAGS.master,
        FLAGS.checkpoint_dir,
        logdir=FLAGS.eval_dir,
        num_evals=num_batches,
        eval_op=eval_ops)


def main(unused_arg):
  eval_model()


if __name__ == '__main__':
  tf.app.run(main)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Build and train mobilenet_v1 with options for quantization."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from datasets import dataset_factory
from nets import mobilenet_v1
from preprocessing import preprocessing_factory
slim = tf.contrib.slim
flags = tf.app.flags
flags.DEFINE_string('master', '', 'Session master')
flags.DEFINE_integer('task', 0, 'Task')
flags.DEFINE_integer('ps_tasks', 0, 'Number of ps')
flags.DEFINE_integer('batch_size', 64, 'Batch size')
flags.DEFINE_integer('num_classes', 1001, 'Number of classes to distinguish')
flags.DEFINE_integer('number_of_steps', None,
'Number of training steps to perform before stopping')
flags.DEFINE_integer('image_size', 224, 'Input image resolution')
flags.DEFINE_float('depth_multiplier', 1.0, 'Depth multiplier for mobilenet')
flags.DEFINE_bool('quantize', False, 'Quantize training')
flags.DEFINE_string('fine_tune_checkpoint', '',
'Checkpoint from which to start finetuning.')
flags.DEFINE_string('logdir', '', 'Directory for writing training event logs')
flags.DEFINE_string('dataset_dir', '', 'Location of dataset')
flags.DEFINE_integer('log_every_n_steps', 100, 'Number of steps per log')
flags.DEFINE_integer('save_summaries_secs', 100,
'How often to save summaries, secs')
flags.DEFINE_integer('save_interval_secs', 100,
'How often to save checkpoints, secs')
FLAGS = flags.FLAGS
_LEARNING_RATE_DECAY_FACTOR = 0.94
def get_learning_rate():
if FLAGS.fine_tune_checkpoint:
# If we are fine tuning a checkpoint we need to start at a lower learning
# rate since we are farther along on training.
return 1e-4
else:
return 0.045
def get_quant_delay():
if FLAGS.fine_tune_checkpoint:
# We can start quantizing immediately if we are finetuning.
return 0
else:
# We need to wait for the model to train a bit before we quantize if we are
# training from scratch.
return 250000

def imagenet_input(is_training):
  """Data reader for imagenet.

  Reads in imagenet data and performs pre-processing on the images.

  Args:
    is_training: bool specifying if train or validation dataset is needed.
  Returns:
    A batch of images and labels.
  """
  if is_training:
    dataset = dataset_factory.get_dataset('imagenet', 'train',
                                          FLAGS.dataset_dir)
  else:
    dataset = dataset_factory.get_dataset('imagenet', 'validation',
                                          FLAGS.dataset_dir)

  provider = slim.dataset_data_provider.DatasetDataProvider(
      dataset,
      shuffle=is_training,
      common_queue_capacity=2 * FLAGS.batch_size,
      common_queue_min=FLAGS.batch_size)
  [image, label] = provider.get(['image', 'label'])

  image_preprocessing_fn = preprocessing_factory.get_preprocessing(
      'mobilenet_v1', is_training=is_training)

  image = image_preprocessing_fn(image, FLAGS.image_size, FLAGS.image_size)

  images, labels = tf.train.batch(
      [image, label],
      batch_size=FLAGS.batch_size,
      num_threads=4,
      capacity=5 * FLAGS.batch_size)
  labels = slim.one_hot_encoding(labels, FLAGS.num_classes)
  return images, labels

def build_model():
  """Builds graph for model to train with rewrites for quantization.

  Returns:
    g: Graph with fake quantization ops and batch norm folding suitable for
    training quantized weights.
    train_tensor: Train op for execution during training.
  """
  g = tf.Graph()
  with g.as_default(), tf.device(
      tf.train.replica_device_setter(FLAGS.ps_tasks)):
    inputs, labels = imagenet_input(is_training=True)
    with slim.arg_scope(mobilenet_v1.mobilenet_v1_arg_scope(is_training=True)):
      logits, _ = mobilenet_v1.mobilenet_v1(
          inputs,
          is_training=True,
          depth_multiplier=FLAGS.depth_multiplier,
          num_classes=FLAGS.num_classes)

    tf.losses.softmax_cross_entropy(labels, logits)

    # Call rewriter to produce graph with fake quant ops and folded batch norms
    # quant_delay delays start of quantization till quant_delay steps, allowing
    # for better model accuracy.
    if FLAGS.quantize:
      tf.contrib.quantize.create_training_graph(quant_delay=get_quant_delay())

    total_loss = tf.losses.get_total_loss(name='total_loss')
    # Configure the learning rate using an exponential decay.
    num_epochs_per_decay = 2.5
    imagenet_size = 1271167
    decay_steps = int(imagenet_size / FLAGS.batch_size * num_epochs_per_decay)

    learning_rate = tf.train.exponential_decay(
        get_learning_rate(),
        tf.train.get_or_create_global_step(),
        decay_steps,
        _LEARNING_RATE_DECAY_FACTOR,
        staircase=True)
    opt = tf.train.GradientDescentOptimizer(learning_rate)

    train_tensor = slim.learning.create_train_op(
        total_loss,
        optimizer=opt)

  slim.summaries.add_scalar_summary(total_loss, 'total_loss', 'losses')
  slim.summaries.add_scalar_summary(learning_rate, 'learning_rate', 'training')
  return g, train_tensor

def get_checkpoint_init_fn():
  """Returns the checkpoint init_fn if the checkpoint is provided."""
  if FLAGS.fine_tune_checkpoint:
    variables_to_restore = slim.get_variables_to_restore()
    global_step_reset = tf.assign(tf.train.get_or_create_global_step(), 0)
    # When restoring from a floating point model, the min/max values for
    # quantized weights and activations are not present.
    # We instruct slim to ignore variables that are missing during restoration
    # by setting ignore_missing_vars=True.
    slim_init_fn = slim.assign_from_checkpoint_fn(
        FLAGS.fine_tune_checkpoint,
        variables_to_restore,
        ignore_missing_vars=True)

    def init_fn(sess):
      slim_init_fn(sess)
      # If we are restoring from a floating point model, we need to initialize
      # the global step to zero for the exponential decay to result in
      # reasonable learning rates.
      sess.run(global_step_reset)
    return init_fn
  else:
    return None


def train_model():
  """Trains mobilenet_v1."""
  g, train_tensor = build_model()
  with g.as_default():
    slim.learning.train(
        train_tensor,
        FLAGS.logdir,
        is_chief=(FLAGS.task == 0),
        master=FLAGS.master,
        log_every_n_steps=FLAGS.log_every_n_steps,
        graph=g,
        number_of_steps=FLAGS.number_of_steps,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs,
        init_fn=get_checkpoint_init_fn(),
        global_step=tf.train.get_global_step())


def main(unused_arg):
  train_model()


if __name__ == '__main__':
  tf.app.run(main)