Commit 65fad62d authored by Nathan Silberman

Full code refactor and added all networks

parent bc0a0a86
# Description: # Description:
# Contains files for loading, training and evaluating TF-Slim 2.0-based models. # Contains files for loading, training and evaluating TF-Slim-based models.
package(default_visibility = [":internal"]) package(default_visibility = [":internal"])
...@@ -7,34 +7,41 @@ licenses(["notice"]) # Apache 2.0 ...@@ -7,34 +7,41 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"]) exports_files(["LICENSE"])
package_group( package_group(name = "internal")
name = "internal",
packages = ["//slim/"],
)
py_library( py_library(
name = "dataset_utils", name = "dataset_utils",
srcs = ["datasets/dataset_utils.py"], srcs = ["datasets/dataset_utils.py"],
) )
py_binary( py_library(
name = "download_and_convert_cifar10", name = "download_and_convert_cifar10",
srcs = ["datasets/download_and_convert_cifar10.py"], srcs = ["datasets/download_and_convert_cifar10.py"],
deps = [":dataset_utils"], deps = [":dataset_utils"],
) )
py_binary( py_library(
name = "download_and_convert_flowers", name = "download_and_convert_flowers",
srcs = ["datasets/download_and_convert_flowers.py"], srcs = ["datasets/download_and_convert_flowers.py"],
deps = [":dataset_utils"], deps = [":dataset_utils"],
) )
py_binary( py_library(
name = "download_and_convert_mnist", name = "download_and_convert_mnist",
srcs = ["datasets/download_and_convert_mnist.py"], srcs = ["datasets/download_and_convert_mnist.py"],
deps = [":dataset_utils"], deps = [":dataset_utils"],
) )
py_binary(
name = "download_and_convert_data",
srcs = ["download_and_convert_data.py"],
deps = [
":download_and_convert_cifar10",
":download_and_convert_flowers",
":download_and_convert_mnist",
],
)
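Note: the per-dataset converters are now plain libraries exposing `run(dataset_dir)` instead of defining their own flags and `main()`; the new `download_and_convert_data` binary above dispatches to them. A minimal sketch of calling one directly (the output directory is a placeholder):

```python
# Hedged sketch: drive a converter as a library call rather than through the
# download_and_convert_data.py flags. '/tmp/mnist' is an illustrative path.
from datasets import download_and_convert_mnist

# Downloads the raw data, writes the train/test TFRecords and labels.txt.
download_and_convert_mnist.run('/tmp/mnist')
```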
py_binary( py_binary(
name = "cifar10", name = "cifar10",
srcs = ["datasets/cifar10.py"], srcs = ["datasets/cifar10.py"],
...@@ -70,78 +77,261 @@ py_library( ...@@ -70,78 +77,261 @@ py_library(
], ],
) )
py_binary(
name = "eval",
srcs = ["eval.py"],
deps = [
":dataset_factory",
":model_deploy",
":model_factory",
":preprocessing_factory",
],
)
py_library( py_library(
name = "model_deploy", name = "model_deploy",
srcs = ["models/model_deploy.py"], srcs = ["deployment/model_deploy.py"],
) )
py_test( py_test(
name = "model_deploy_test", name = "model_deploy_test",
srcs = ["models/model_deploy_test.py"], srcs = ["deployment/model_deploy_test.py"],
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [":model_deploy"], deps = [":model_deploy"],
) )
py_library( py_library(
name = "cifar10_preprocessing", name = "cifarnet_preprocessing",
srcs = ["models/cifar10_preprocessing.py"], srcs = ["preprocessing/cifarnet_preprocessing.py"],
) )
py_library( py_library(
name = "inception_preprocessing", name = "inception_preprocessing",
srcs = ["models/inception_preprocessing.py"], srcs = ["preprocessing/inception_preprocessing.py"],
) )
py_library( py_library(
name = "lenet_preprocessing", name = "lenet_preprocessing",
srcs = ["models/lenet_preprocessing.py"], srcs = ["preprocessing/lenet_preprocessing.py"],
) )
py_library( py_library(
name = "vgg_preprocessing", name = "vgg_preprocessing",
srcs = ["models/vgg_preprocessing.py"], srcs = ["preprocessing/vgg_preprocessing.py"],
) )
py_library( py_library(
name = "preprocessing_factory", name = "preprocessing_factory",
srcs = ["models/preprocessing_factory.py"], srcs = ["preprocessing/preprocessing_factory.py"],
deps = [ deps = [
":cifar10_preprocessing", ":cifarnet_preprocessing",
":inception_preprocessing", ":inception_preprocessing",
":lenet_preprocessing", ":lenet_preprocessing",
":vgg_preprocessing", ":vgg_preprocessing",
], ],
) )
# Typical networks definitions.
py_library(
name = "nets",
deps = [
":alexnet",
":cifarnet",
":inception",
":lenet",
":overfeat",
":resnet_v1",
":resnet_v2",
":vgg",
],
)
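For orientation, the grouped `nets` target lets callers import any of the bundled model definitions directly. A small sketch using the VGG module (the `from nets import vgg` path is assumed from the new layout, and the weight decay, class count, and input shape are illustrative):

```python
# Illustrative direct use of one of the bundled network definitions.
import tensorflow as tf
from nets import vgg

slim = tf.contrib.slim

images = tf.placeholder(tf.float32, [None, 224, 224, 3])
with slim.arg_scope(vgg.vgg_arg_scope(weight_decay=0.0005)):
    logits, end_points = vgg.vgg_16(images, num_classes=1000, is_training=False)
```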
py_library(
name = "alexnet",
srcs = ["nets/alexnet.py"],
srcs_version = "PY2AND3",
)
py_test(
name = "alexnet_test",
size = "medium",
srcs = ["nets/alexnet_test.py"],
srcs_version = "PY2AND3",
deps = [":alexnet"],
)
py_library(
name = "cifarnet",
srcs = ["nets/cifarnet.py"],
)
py_library(
name = "inception",
srcs = ["nets/inception.py"],
srcs_version = "PY2AND3",
deps = [
":inception_resnet_v2",
":inception_v1",
":inception_v2",
":inception_v3",
],
)
py_library(
name = "inception_v1",
srcs = ["nets/inception_v1.py"],
srcs_version = "PY2AND3",
)
py_library(
name = "inception_v2",
srcs = ["nets/inception_v2.py"],
srcs_version = "PY2AND3",
)
py_library(
name = "inception_v3",
srcs = ["nets/inception_v3.py"],
srcs_version = "PY2AND3",
)
py_library(
name = "inception_resnet_v2",
srcs = ["nets/inception_resnet_v2.py"],
srcs_version = "PY2AND3",
)
py_test(
name = "inception_v1_test",
size = "large",
srcs = ["nets/inception_v1_test.py"],
shard_count = 3,
srcs_version = "PY2AND3",
deps = [":inception"],
)
py_test(
name = "inception_v2_test",
size = "large",
srcs = ["nets/inception_v2_test.py"],
shard_count = 3,
srcs_version = "PY2AND3",
deps = [":inception"],
)
py_test(
name = "inception_v3_test",
size = "large",
srcs = ["nets/inception_v3_test.py"],
shard_count = 3,
srcs_version = "PY2AND3",
deps = [":inception"],
)
py_test(
name = "inception_resnet_v2_test",
size = "large",
srcs = ["nets/inception_resnet_v2_test.py"],
shard_count = 3,
srcs_version = "PY2AND3",
deps = [":inception"],
)
py_library( py_library(
name = "lenet", name = "lenet",
srcs = ["nets/lenet.py"], srcs = ["nets/lenet.py"],
) )
py_library( py_library(
name = "model_factory", name = "overfeat",
srcs = ["models/model_factory.py"], srcs = ["nets/overfeat.py"],
deps = [":lenet"], srcs_version = "PY2AND3",
)
py_test(
name = "overfeat_test",
size = "medium",
srcs = ["nets/overfeat_test.py"],
srcs_version = "PY2AND3",
deps = [":overfeat"],
)
py_library(
name = "resnet_utils",
srcs = ["nets/resnet_utils.py"],
srcs_version = "PY2AND3",
)
py_library(
name = "resnet_v1",
srcs = ["nets/resnet_v1.py"],
srcs_version = "PY2AND3",
deps = [
":resnet_utils",
],
)
py_test(
name = "resnet_v1_test",
size = "medium",
srcs = ["nets/resnet_v1_test.py"],
srcs_version = "PY2AND3",
deps = [":resnet_v1"],
)
py_library(
name = "resnet_v2",
srcs = ["nets/resnet_v2.py"],
srcs_version = "PY2AND3",
deps = [
":resnet_utils",
],
)
py_test(
name = "resnet_v2_test",
size = "medium",
srcs = ["nets/resnet_v2_test.py"],
srcs_version = "PY2AND3",
deps = [":resnet_v2"],
)
py_library(
name = "vgg",
srcs = ["nets/vgg.py"],
srcs_version = "PY2AND3",
)
py_test(
name = "vgg_test",
size = "medium",
srcs = ["nets/vgg_test.py"],
srcs_version = "PY2AND3",
deps = [":vgg"],
)
py_library(
name = "nets_factory",
srcs = ["nets/nets_factory.py"],
deps = [":nets"],
)
py_test(
name = "nets_factory_test",
size = "medium",
srcs = ["nets/nets_factory_test.py"],
srcs_version = "PY2AND3",
deps = [":nets_factory"],
)
py_binary(
name = "train_image_classifier",
srcs = ["train_image_classifier.py"],
deps = [
":dataset_factory",
":model_deploy",
":nets_factory",
":preprocessing_factory",
],
) )
py_binary( py_binary(
name = "train", name = "eval_image_classifier",
srcs = ["train.py"], srcs = ["eval_image_classifier.py"],
deps = [ deps = [
":dataset_factory", ":dataset_factory",
":model_deploy", ":model_deploy",
":model_factory", ":nets_factory",
":preprocessing_factory", ":preprocessing_factory",
], ],
) )
...@@ -25,7 +25,7 @@ from __future__ import print_function ...@@ -25,7 +25,7 @@ from __future__ import print_function
import os import os
import tensorflow as tf import tensorflow as tf
from slim.datasets import dataset_utils from datasets import dataset_utils
slim = tf.contrib.slim slim = tf.contrib.slim
......
...@@ -18,10 +18,10 @@ from __future__ import absolute_import ...@@ -18,10 +18,10 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from slim.datasets import cifar10 from datasets import cifar10
from slim.datasets import flowers from datasets import flowers
from slim.datasets import imagenet from datasets import imagenet
from slim.datasets import mnist from datasets import mnist
datasets_map = { datasets_map = {
'cifar10': cifar10, 'cifar10': cifar10,
......
...@@ -18,6 +18,10 @@ from __future__ import division ...@@ -18,6 +18,10 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import os import os
import sys
import tarfile
from six.moves import urllib
import tensorflow as tf import tensorflow as tf
LABELS_FILENAME = 'labels.txt' LABELS_FILENAME = 'labels.txt'
...@@ -59,6 +63,27 @@ def image_to_tfexample(image_data, image_format, height, width, class_id): ...@@ -59,6 +63,27 @@ def image_to_tfexample(image_data, image_format, height, width, class_id):
})) }))
def download_and_uncompress_tarball(tarball_url, dataset_dir):
"""Downloads the `tarball_url` and uncompresses it locally.
Args:
tarball_url: The URL of a tarball file.
dataset_dir: The directory where the temporary files are stored.
"""
filename = tarball_url.split('/')[-1]
filepath = os.path.join(dataset_dir, filename)
def _progress(count, block_size, total_size):
sys.stdout.write('\r>> Downloading %s %.1f%%' % (
filename, float(count * block_size) / float(total_size) * 100.0))
sys.stdout.flush()
filepath, _ = urllib.request.urlretrieve(tarball_url, filepath, _progress)
print()
statinfo = os.stat(filepath)
print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
tarfile.open(filepath, 'r:gz').extractall(dataset_dir)
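The tarball helper is now shared by the converters below; a minimal sketch of calling it directly (the URL and directory mirror the CIFAR-10 converter and are illustrative here):

```python
# Hedged usage sketch of the shared download helper added above.
from datasets import dataset_utils

_DATA_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
dataset_dir = '/tmp/cifar10'  # assumed to exist already, as run() ensures

# Fetches the tarball into dataset_dir, printing progress, then extracts it.
dataset_utils.download_and_uncompress_tarball(_DATA_URL, dataset_dir)
```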
def write_label_file(labels_to_class_names, dataset_dir, def write_label_file(labels_to_class_names, dataset_dir,
filename=LABELS_FILENAME): filename=LABELS_FILENAME):
"""Writes a file with the list of class names. """Writes a file with the list of class names.
......
...@@ -14,16 +14,13 @@ ...@@ -14,16 +14,13 @@
# ============================================================================== # ==============================================================================
r"""Downloads and converts cifar10 data to TFRecords of TF-Example protos. r"""Downloads and converts cifar10 data to TFRecords of TF-Example protos.
This script downloads the cifar10 data, uncompresses it, reads the files This module downloads the cifar10 data, uncompresses it, reads the files
that make up the cifar10 data and creates two TFRecord datasets: one for train that make up the cifar10 data and creates two TFRecord datasets: one for train
and one for test. Each TFRecord dataset is comprised of a set of TF-Example and one for test. Each TFRecord dataset is comprised of a set of TF-Example
protocol buffers, each of which contain a single image and label. protocol buffers, each of which contain a single image and label.
The script should take several minutes to run. The script should take several minutes to run.
Usage:
$ bazel build slim:download_and_convert_cifar10
$ .bazel-bin/slim/download_and_convert_cifar10 --dataset_dir=[DIRECTORY]
""" """
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -38,14 +35,7 @@ import numpy as np ...@@ -38,14 +35,7 @@ import numpy as np
from six.moves import urllib from six.moves import urllib
import tensorflow as tf import tensorflow as tf
from slim.datasets import dataset_utils from datasets import dataset_utils
tf.app.flags.DEFINE_string(
'dataset_dir',
None,
'The directory where the output TFRecords and temporary files are saved.')
FLAGS = tf.app.flags.FLAGS
# The URL where the CIFAR data can be downloaded. # The URL where the CIFAR data can be downloaded.
_DATA_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' _DATA_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
...@@ -115,16 +105,17 @@ def _add_to_tfrecord(filename, tfrecord_writer, offset=0): ...@@ -115,16 +105,17 @@ def _add_to_tfrecord(filename, tfrecord_writer, offset=0):
return offset + num_images return offset + num_images
def _get_output_filename(split_name): def _get_output_filename(dataset_dir, split_name):
"""Creates the output filename. """Creates the output filename.
Args: Args:
dataset_dir: The dataset directory where the dataset is stored.
split_name: The name of the train/test split. split_name: The name of the train/test split.
Returns: Returns:
An absolute file path. An absolute file path.
""" """
return '%s/cifar10_%s.tfrecord' % (FLAGS.dataset_dir, split_name) return '%s/cifar10_%s.tfrecord' % (dataset_dir, split_name)
def _download_and_uncompress_dataset(dataset_dir): def _download_and_uncompress_dataset(dataset_dir):
...@@ -162,39 +153,43 @@ def _clean_up_temporary_files(dataset_dir): ...@@ -162,39 +153,43 @@ def _clean_up_temporary_files(dataset_dir):
tf.gfile.DeleteRecursively(tmp_dir) tf.gfile.DeleteRecursively(tmp_dir)
def main(_): def run(dataset_dir):
if not FLAGS.dataset_dir: """Runs the download and conversion operation.
raise ValueError('You must supply the dataset directory with --dataset_dir')
if not tf.gfile.Exists(FLAGS.dataset_dir): Args:
tf.gfile.MakeDirs(FLAGS.dataset_dir) dataset_dir: The dataset directory where the dataset is stored.
"""
if not tf.gfile.Exists(dataset_dir):
tf.gfile.MakeDirs(dataset_dir)
_download_and_uncompress_dataset(FLAGS.dataset_dir) training_filename = _get_output_filename(dataset_dir, 'train')
testing_filename = _get_output_filename(dataset_dir, 'test')
if tf.gfile.Exists(training_filename) and tf.gfile.Exists(testing_filename):
print('Dataset files already exist. Exiting without re-creating them.')
return
dataset_utils.download_and_uncompress_tarball(_DATA_URL, dataset_dir)
# First, process the training data: # First, process the training data:
output_file = _get_output_filename('train') with tf.python_io.TFRecordWriter(training_filename) as tfrecord_writer:
with tf.python_io.TFRecordWriter(output_file) as tfrecord_writer:
offset = 0 offset = 0
for i in range(_NUM_TRAIN_FILES): for i in range(_NUM_TRAIN_FILES):
filename = os.path.join(FLAGS.dataset_dir, filename = os.path.join(dataset_dir,
'cifar-10-batches-py', 'cifar-10-batches-py',
'data_batch_%d' % (i + 1)) # 1-indexed. 'data_batch_%d' % (i + 1)) # 1-indexed.
offset = _add_to_tfrecord(filename, tfrecord_writer, offset) offset = _add_to_tfrecord(filename, tfrecord_writer, offset)
# Next, process the testing data: # Next, process the testing data:
output_file = _get_output_filename('test') with tf.python_io.TFRecordWriter(testing_filename) as tfrecord_writer:
with tf.python_io.TFRecordWriter(output_file) as tfrecord_writer: filename = os.path.join(dataset_dir,
filename = os.path.join(FLAGS.dataset_dir,
'cifar-10-batches-py', 'cifar-10-batches-py',
'test_batch') 'test_batch')
_add_to_tfrecord(filename, tfrecord_writer) _add_to_tfrecord(filename, tfrecord_writer)
# Finally, write the labels file: # Finally, write the labels file:
labels_to_class_names = dict(zip(range(len(_CLASS_NAMES)), _CLASS_NAMES)) labels_to_class_names = dict(zip(range(len(_CLASS_NAMES)), _CLASS_NAMES))
dataset_utils.write_label_file(labels_to_class_names, FLAGS.dataset_dir) dataset_utils.write_label_file(labels_to_class_names, dataset_dir)
_clean_up_temporary_files(FLAGS.dataset_dir) _clean_up_temporary_files(dataset_dir)
print('\nFinished converting the Cifar10 dataset!') print('\nFinished converting the Cifar10 dataset!')
if __name__ == '__main__':
tf.app.run()
...@@ -14,17 +14,13 @@ ...@@ -14,17 +14,13 @@
# ============================================================================== # ==============================================================================
r"""Downloads and converts Flowers data to TFRecords of TF-Example protos. r"""Downloads and converts Flowers data to TFRecords of TF-Example protos.
This script downloads the Flowers data, uncompresses it, reads the files This module downloads the Flowers data, uncompresses it, reads the files
that make up the Flowers data and creates two TFRecord datasets: one for train that make up the Flowers data and creates two TFRecord datasets: one for train
and one for test. Each TFRecord dataset is comprised of a set of TF-Example and one for test. Each TFRecord dataset is comprised of a set of TF-Example
protocol buffers, each of which contain a single image and label. protocol buffers, each of which contain a single image and label.
The script should take about a minute to run. The script should take about a minute to run.
Usage:
$ bazel build slim:download_and_convert_flowers
$ .bazel-bin/slim/download_and_convert_flowers --dataset_dir=[DIRECTORY]
""" """
from __future__ import absolute_import from __future__ import absolute_import
...@@ -35,19 +31,10 @@ import math ...@@ -35,19 +31,10 @@ import math
import os import os
import random import random
import sys import sys
import tarfile
from six.moves import urllib
import tensorflow as tf import tensorflow as tf
from slim.datasets import dataset_utils from datasets import dataset_utils
tf.app.flags.DEFINE_string(
'dataset_dir',
None,
'The directory where the output TFRecords and temporary files are saved.')
FLAGS = tf.app.flags.FLAGS
# The URL where the Flowers data can be downloaded. # The URL where the Flowers data can be downloaded.
_DATA_URL = 'http://download.tensorflow.org/example_images/flower_photos.tgz' _DATA_URL = 'http://download.tensorflow.org/example_images/flower_photos.tgz'
...@@ -82,27 +69,6 @@ class ImageReader(object): ...@@ -82,27 +69,6 @@ class ImageReader(object):
return image return image
def _download_dataset(dataset_dir):
"""Downloads the flowers data and uncompresses it locally.
Args:
dataset_dir: The directory where the temporary files are stored.
"""
filename = _DATA_URL.split('/')[-1]
filepath = os.path.join(dataset_dir, filename)
if not os.path.exists(filepath):
def _progress(count, block_size, total_size):
sys.stdout.write('\r>> Downloading %s %.1f%%' % (
filename, float(count * block_size) / float(total_size) * 100.0))
sys.stdout.flush()
filepath, _ = urllib.request.urlretrieve(_DATA_URL, filepath, _progress)
print()
statinfo = os.stat(filepath)
print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
tarfile.open(filepath, 'r:gz').extractall(dataset_dir)
def _get_filenames_and_classes(dataset_dir): def _get_filenames_and_classes(dataset_dir):
"""Returns a list of filenames and inferred class names. """Returns a list of filenames and inferred class names.
...@@ -132,6 +98,12 @@ def _get_filenames_and_classes(dataset_dir): ...@@ -132,6 +98,12 @@ def _get_filenames_and_classes(dataset_dir):
return photo_filenames, sorted(class_names) return photo_filenames, sorted(class_names)
def _get_dataset_filename(dataset_dir, split_name, shard_id):
output_filename = 'flowers_%s_%05d-of-%05d.tfrecord' % (
split_name, shard_id, _NUM_SHARDS)
return os.path.join(dataset_dir, output_filename)
def _convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir): def _convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir):
"""Converts the given filenames to a TFRecord dataset. """Converts the given filenames to a TFRecord dataset.
...@@ -152,9 +124,8 @@ def _convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir): ...@@ -152,9 +124,8 @@ def _convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir):
with tf.Session('') as sess: with tf.Session('') as sess:
for shard_id in range(_NUM_SHARDS): for shard_id in range(_NUM_SHARDS):
output_filename = 'flowers_%s_%05d-of-%05d.tfrecord' % ( output_filename = _get_dataset_filename(
split_name, shard_id, _NUM_SHARDS) dataset_dir, split_name, shard_id)
output_filename = os.path.join(dataset_dir, output_filename)
with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer: with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
start_ndx = shard_id * num_per_shard start_ndx = shard_id * num_per_shard
...@@ -193,15 +164,31 @@ def _clean_up_temporary_files(dataset_dir): ...@@ -193,15 +164,31 @@ def _clean_up_temporary_files(dataset_dir):
tf.gfile.DeleteRecursively(tmp_dir) tf.gfile.DeleteRecursively(tmp_dir)
def main(_): def _dataset_exists(dataset_dir):
if not FLAGS.dataset_dir: for split_name in ['train', 'validation']:
raise ValueError('You must supply the dataset directory with --dataset_dir') for shard_id in range(_NUM_SHARDS):
output_filename = _get_dataset_filename(
dataset_dir, split_name, shard_id)
if not tf.gfile.Exists(output_filename):
return False
return True
def run(dataset_dir):
"""Runs the download and conversion operation.
Args:
dataset_dir: The dataset directory where the dataset is stored.
"""
if not tf.gfile.Exists(dataset_dir):
tf.gfile.MakeDirs(dataset_dir)
if not tf.gfile.Exists(FLAGS.dataset_dir): if _dataset_exists(dataset_dir):
tf.gfile.MakeDirs(FLAGS.dataset_dir) print('Dataset files already exist. Exiting without re-creating them.')
return
_download_dataset(FLAGS.dataset_dir) dataset_utils.download_and_uncompress_tarball(_DATA_URL, dataset_dir)
photo_filenames, class_names = _get_filenames_and_classes(FLAGS.dataset_dir) photo_filenames, class_names = _get_filenames_and_classes(dataset_dir)
class_names_to_ids = dict(zip(class_names, range(len(class_names)))) class_names_to_ids = dict(zip(class_names, range(len(class_names))))
# Divide into train and test: # Divide into train and test:
...@@ -212,16 +199,14 @@ def main(_): ...@@ -212,16 +199,14 @@ def main(_):
# First, convert the training and validation sets. # First, convert the training and validation sets.
_convert_dataset('train', training_filenames, class_names_to_ids, _convert_dataset('train', training_filenames, class_names_to_ids,
FLAGS.dataset_dir) dataset_dir)
_convert_dataset('validation', validation_filenames, class_names_to_ids, _convert_dataset('validation', validation_filenames, class_names_to_ids,
FLAGS.dataset_dir) dataset_dir)
# Finally, write the labels file: # Finally, write the labels file:
labels_to_class_names = dict(zip(range(len(class_names)), class_names)) labels_to_class_names = dict(zip(range(len(class_names)), class_names))
dataset_utils.write_label_file(labels_to_class_names, FLAGS.dataset_dir) dataset_utils.write_label_file(labels_to_class_names, dataset_dir)
_clean_up_temporary_files(FLAGS.dataset_dir) _clean_up_temporary_files(dataset_dir)
print('\nFinished converting the Flowers dataset!') print('\nFinished converting the Flowers dataset!')
if __name__ == '__main__':
tf.app.run()
...@@ -14,17 +14,13 @@ ...@@ -14,17 +14,13 @@
# ============================================================================== # ==============================================================================
r"""Downloads and converts MNIST data to TFRecords of TF-Example protos. r"""Downloads and converts MNIST data to TFRecords of TF-Example protos.
This script downloads the MNIST data, uncompresses it, reads the files This module downloads the MNIST data, uncompresses it, reads the files
that make up the MNIST data and creates two TFRecord datasets: one for train that make up the MNIST data and creates two TFRecord datasets: one for train
and one for test. Each TFRecord dataset is comprised of a set of TF-Example and one for test. Each TFRecord dataset is comprised of a set of TF-Example
protocol buffers, each of which contain a single image and label. protocol buffers, each of which contain a single image and label.
The script should take about a minute to run. The script should take about a minute to run.
Usage:
$ bazel build slim:download_and_convert_mnist
$ .bazel-bin/slim/download_and_convert_mnist --dataset_dir=[DIRECTORY]
""" """
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -38,14 +34,7 @@ import numpy as np ...@@ -38,14 +34,7 @@ import numpy as np
from six.moves import urllib from six.moves import urllib
import tensorflow as tf import tensorflow as tf
from slim.datasets import dataset_utils from datasets import dataset_utils
tf.app.flags.DEFINE_string(
'dataset_dir',
None,
'The directory where the output TFRecords and temporary files are saved.')
FLAGS = tf.app.flags.FLAGS
# The URLs where the MNIST data can be downloaded. # The URLs where the MNIST data can be downloaded.
_DATA_URL = 'http://yann.lecun.com/exdb/mnist/' _DATA_URL = 'http://yann.lecun.com/exdb/mnist/'
...@@ -140,16 +129,17 @@ def _add_to_tfrecord(data_filename, labels_filename, num_images, ...@@ -140,16 +129,17 @@ def _add_to_tfrecord(data_filename, labels_filename, num_images,
tfrecord_writer.write(example.SerializeToString()) tfrecord_writer.write(example.SerializeToString())
def _get_output_filename(split_name): def _get_output_filename(dataset_dir, split_name):
"""Creates the output filename. """Creates the output filename.
Args: Args:
dataset_dir: The directory where the temporary files are stored.
split_name: The name of the train/test split. split_name: The name of the train/test split.
Returns: Returns:
An absolute file path. An absolute file path.
""" """
return '%s/mnist_%s.tfrecord' % (FLAGS.dataset_dir, split_name) return '%s/mnist_%s.tfrecord' % (dataset_dir, split_name)
def _download_dataset(dataset_dir): def _download_dataset(dataset_dir):
...@@ -193,35 +183,39 @@ def _clean_up_temporary_files(dataset_dir): ...@@ -193,35 +183,39 @@ def _clean_up_temporary_files(dataset_dir):
tf.gfile.Remove(filepath) tf.gfile.Remove(filepath)
def main(_): def run(dataset_dir):
if not FLAGS.dataset_dir: """Runs the download and conversion operation.
raise ValueError('You must supply the dataset directory with --dataset_dir')
if not tf.gfile.Exists(FLAGS.dataset_dir): Args:
tf.gfile.MakeDirs(FLAGS.dataset_dir) dataset_dir: The dataset directory where the dataset is stored.
"""
if not tf.gfile.Exists(dataset_dir):
tf.gfile.MakeDirs(dataset_dir)
training_filename = _get_output_filename(dataset_dir, 'train')
testing_filename = _get_output_filename(dataset_dir, 'test')
_download_dataset(FLAGS.dataset_dir) if tf.gfile.Exists(training_filename) and tf.gfile.Exists(testing_filename):
print('Dataset files already exist. Exiting without re-creating them.')
return
_download_dataset(dataset_dir)
# First, process the training data: # First, process the training data:
output_file = _get_output_filename('train') with tf.python_io.TFRecordWriter(training_filename) as tfrecord_writer:
with tf.python_io.TFRecordWriter(output_file) as tfrecord_writer: data_filename = os.path.join(dataset_dir, _TRAIN_DATA_FILENAME)
data_filename = os.path.join(FLAGS.dataset_dir, _TRAIN_DATA_FILENAME) labels_filename = os.path.join(dataset_dir, _TRAIN_LABELS_FILENAME)
labels_filename = os.path.join(FLAGS.dataset_dir, _TRAIN_LABELS_FILENAME)
_add_to_tfrecord(data_filename, labels_filename, 60000, tfrecord_writer) _add_to_tfrecord(data_filename, labels_filename, 60000, tfrecord_writer)
# Next, process the testing data: # Next, process the testing data:
output_file = _get_output_filename('test') with tf.python_io.TFRecordWriter(testing_filename) as tfrecord_writer:
with tf.python_io.TFRecordWriter(output_file) as tfrecord_writer: data_filename = os.path.join(dataset_dir, _TEST_DATA_FILENAME)
data_filename = os.path.join(FLAGS.dataset_dir, _TEST_DATA_FILENAME) labels_filename = os.path.join(dataset_dir, _TEST_LABELS_FILENAME)
labels_filename = os.path.join(FLAGS.dataset_dir, _TEST_LABELS_FILENAME)
_add_to_tfrecord(data_filename, labels_filename, 10000, tfrecord_writer) _add_to_tfrecord(data_filename, labels_filename, 10000, tfrecord_writer)
# Finally, write the labels file: # Finally, write the labels file:
labels_to_class_names = dict(zip(range(len(_CLASS_NAMES)), _CLASS_NAMES)) labels_to_class_names = dict(zip(range(len(_CLASS_NAMES)), _CLASS_NAMES))
dataset_utils.write_label_file(labels_to_class_names, FLAGS.dataset_dir) dataset_utils.write_label_file(labels_to_class_names, dataset_dir)
_clean_up_temporary_files(FLAGS.dataset_dir) _clean_up_temporary_files(dataset_dir)
print('\nFinished converting the MNIST dataset!') print('\nFinished converting the MNIST dataset!')
if __name__ == '__main__':
tf.app.run()
...@@ -25,7 +25,7 @@ from __future__ import print_function ...@@ -25,7 +25,7 @@ from __future__ import print_function
import os import os
import tensorflow as tf import tensorflow as tf
from slim.datasets import dataset_utils from datasets import dataset_utils
slim = tf.contrib.slim slim = tf.contrib.slim
......
...@@ -33,8 +33,11 @@ from __future__ import division ...@@ -33,8 +33,11 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import os import os
from six.moves import urllib
import tensorflow as tf import tensorflow as tf
from datasets import dataset_utils
slim = tf.contrib.slim slim = tf.contrib.slim
# TODO(nsilberman): Add tfrecord file type once the script is updated. # TODO(nsilberman): Add tfrecord file type once the script is updated.
...@@ -55,7 +58,61 @@ _ITEMS_TO_DESCRIPTIONS = { ...@@ -55,7 +58,61 @@ _ITEMS_TO_DESCRIPTIONS = {
_NUM_CLASSES = 1001 _NUM_CLASSES = 1001
# TODO(nsilberman): Add _LABELS_TO_NAMES
def create_readable_names_for_imagenet_labels():
"""Create a dict mapping label id to human readable string.
Returns:
labels_to_names: dictionary where keys are integers from 0 to 1000
and values are human-readable names.
We retrieve a synset file, which contains a list of valid synset labels used
by the ILSVRC competition. There is one synset per line, e.g.
# n01440764
# n01443537
We also retrieve a synset_to_human_file, which contains a mapping from synsets
to human-readable names for every synset in Imagenet. These are stored in a
tsv format, as follows:
# n02119247 black fox
# n02119359 silver fox
We assign each synset (in alphabetical order) an integer, starting from 1
(since 0 is reserved for the background class).
Code is based on
https://github.com/tensorflow/models/blob/master/inception/inception/data/build_imagenet_data.py#L463
"""
# pylint: disable=g-line-too-long
base_url = 'https://raw.githubusercontent.com/tensorflow/models/master/inception/inception/data/'
synset_url = '{}/imagenet_lsvrc_2015_synsets.txt'.format(base_url)
synset_to_human_url = '{}/imagenet_metadata.txt'.format(base_url)
filename, _ = urllib.request.urlretrieve(synset_url)
synset_list = [s.strip() for s in open(filename).readlines()]
num_synsets_in_ilsvrc = len(synset_list)
assert num_synsets_in_ilsvrc == 1000
filename, _ = urllib.request.urlretrieve(synset_to_human_url)
synset_to_human_list = open(filename).readlines()
num_synsets_in_all_imagenet = len(synset_to_human_list)
assert num_synsets_in_all_imagenet == 21842
synset_to_human = {}
for s in synset_to_human_list:
parts = s.strip().split('\t')
assert len(parts) == 2
synset = parts[0]
human = parts[1]
synset_to_human[synset] = human
label_index = 1
labels_to_names = {0: 'background'}
for synset in synset_list:
name = synset_to_human[synset]
labels_to_names[label_index] = name
label_index += 1
return labels_to_names
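A short sketch of the new helper in isolation (it fetches the synset files from GitHub at call time, so network access is required; values shown are indicative):

```python
# Illustrative use of create_readable_names_for_imagenet_labels().
from datasets import imagenet

labels_to_names = imagenet.create_readable_names_for_imagenet_labels()
print(len(labels_to_names))   # 1001 entries: background plus 1000 synsets
print(labels_to_names[0])     # 'background'
```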
def get_split(split_name, dataset_dir, file_pattern=None, reader=None): def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
...@@ -119,10 +176,18 @@ def get_split(split_name, dataset_dir, file_pattern=None, reader=None): ...@@ -119,10 +176,18 @@ def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
decoder = slim.tfexample_decoder.TFExampleDecoder( decoder = slim.tfexample_decoder.TFExampleDecoder(
keys_to_features, items_to_handlers) keys_to_features, items_to_handlers)
labels_to_names = None
if dataset_utils.has_labels(dataset_dir):
labels_to_names = dataset_utils.read_label_file(dataset_dir)
else:
labels_to_names = create_readable_names_for_imagenet_labels()
dataset_utils.write_label_file(labels_to_names, dataset_dir)
return slim.dataset.Dataset( return slim.dataset.Dataset(
data_sources=file_pattern, data_sources=file_pattern,
reader=reader, reader=reader,
decoder=decoder, decoder=decoder,
num_samples=_SPLITS_TO_SIZES[split_name], num_samples=_SPLITS_TO_SIZES[split_name],
items_to_descriptions=_ITEMS_TO_DESCRIPTIONS, items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
num_classes=_NUM_CLASSES) num_classes=_NUM_CLASSES,
labels_to_names=labels_to_names)
...@@ -25,7 +25,7 @@ from __future__ import print_function ...@@ -25,7 +25,7 @@ from __future__ import print_function
import os import os
import tensorflow as tf import tensorflow as tf
from slim.datasets import dataset_utils from datasets import dataset_utils
slim = tf.contrib.slim slim = tf.contrib.slim
......
...@@ -30,7 +30,7 @@ Usage: ...@@ -30,7 +30,7 @@ Usage:
g = tf.Graph() g = tf.Graph()
# Set up DeploymentConfig # Set up DeploymentConfig
config = slim.DeploymentConfig(num_clones=2, clone_on_cpu=True) config = model_deploy.DeploymentConfig(num_clones=2, clone_on_cpu=True)
# Create the global step on the device storing the variables. # Create the global step on the device storing the variables.
with tf.device(config.variables_device()): with tf.device(config.variables_device()):
...@@ -51,7 +51,8 @@ Usage: ...@@ -51,7 +51,8 @@ Usage:
predictions = CreateNetwork(images) predictions = CreateNetwork(images)
slim.losses.log_loss(predictions, labels) slim.losses.log_loss(predictions, labels)
model_dp = slim.deploy(config, model_fn, [inputs_queue], optimizer=optimizer) model_dp = model_deploy.deploy(config, model_fn, [inputs_queue],
optimizer=optimizer)
# Run training. # Run training.
slim.learning.train(model_dp.train_op, my_log_dir, slim.learning.train(model_dp.train_op, my_log_dir,
...@@ -240,7 +241,7 @@ def _gather_clone_loss(clone, num_clones, regularization_losses): ...@@ -240,7 +241,7 @@ def _gather_clone_loss(clone, num_clones, regularization_losses):
def _optimize_clone(optimizer, clone, num_clones, regularization_losses, def _optimize_clone(optimizer, clone, num_clones, regularization_losses,
kwargs=None): **kwargs):
"""Compute losses and gradients for a single clone. """Compute losses and gradients for a single clone.
Args: Args:
...@@ -249,7 +250,7 @@ def _optimize_clone(optimizer, clone, num_clones, regularization_losses, ...@@ -249,7 +250,7 @@ def _optimize_clone(optimizer, clone, num_clones, regularization_losses,
num_clones: The number of clones being deployed. num_clones: The number of clones being deployed.
regularization_losses: Possibly empty list of regularization_losses regularization_losses: Possibly empty list of regularization_losses
to add to the clone losses. to add to the clone losses.
kwargs: Dict of kwarg to pass to compute_gradients(). **kwargs: Dict of kwarg to pass to compute_gradients().
Returns: Returns:
A tuple (clone_loss, clone_grads_and_vars). A tuple (clone_loss, clone_grads_and_vars).
...@@ -267,7 +268,7 @@ def _optimize_clone(optimizer, clone, num_clones, regularization_losses, ...@@ -267,7 +268,7 @@ def _optimize_clone(optimizer, clone, num_clones, regularization_losses,
def optimize_clones(clones, optimizer, def optimize_clones(clones, optimizer,
regularization_losses=None, regularization_losses=None,
kwargs=None): **kwargs):
"""Compute clone losses and gradients for the given list of `Clones`. """Compute clone losses and gradients for the given list of `Clones`.
Note: The regularization_losses are added to the first clone losses. Note: The regularization_losses are added to the first clone losses.
...@@ -278,7 +279,7 @@ def optimize_clones(clones, optimizer, ...@@ -278,7 +279,7 @@ def optimize_clones(clones, optimizer,
regularization_losses: Optional list of regularization losses. If None it regularization_losses: Optional list of regularization losses. If None it
will gather them from tf.GraphKeys.REGULARIZATION_LOSSES. Pass `[]` to will gather them from tf.GraphKeys.REGULARIZATION_LOSSES. Pass `[]` to
exclude them. exclude them.
kwargs: Optional list of keyword arguments to pass to `compute_gradients`. **kwargs: Optional list of keyword arguments to pass to `compute_gradients`.
Returns: Returns:
A tuple (total_loss, grads_and_vars). A tuple (total_loss, grads_and_vars).
...@@ -290,7 +291,6 @@ def optimize_clones(clones, optimizer, ...@@ -290,7 +291,6 @@ def optimize_clones(clones, optimizer,
""" """
grads_and_vars = [] grads_and_vars = []
clones_losses = [] clones_losses = []
kwargs = kwargs or {}
num_clones = len(clones) num_clones = len(clones)
if regularization_losses is None: if regularization_losses is None:
regularization_losses = tf.get_collection( regularization_losses = tf.get_collection(
...@@ -298,7 +298,7 @@ def optimize_clones(clones, optimizer, ...@@ -298,7 +298,7 @@ def optimize_clones(clones, optimizer,
for clone in clones: for clone in clones:
with tf.name_scope(clone.scope): with tf.name_scope(clone.scope):
clone_loss, clone_grad = _optimize_clone( clone_loss, clone_grad = _optimize_clone(
optimizer, clone, num_clones, regularization_losses, kwargs) optimizer, clone, num_clones, regularization_losses, **kwargs)
if clone_loss is not None: if clone_loss is not None:
clones_losses.append(clone_loss) clones_losses.append(clone_loss)
grads_and_vars.append(clone_grad) grads_and_vars.append(clone_grad)
......
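The `kwargs` dict arguments are replaced by true `**kwargs` forwarding to `optimizer.compute_gradients()`. A minimal sketch of the new call style, assuming `clones` and `optimizer` were built as in the module's usage docstring:

```python
# Hedged sketch of the new keyword-forwarding style; any extra keyword
# arguments are passed straight through to optimizer.compute_gradients().
from deployment import model_deploy

total_loss, grads_and_vars = model_deploy.optimize_clones(
    clones, optimizer,
    colocate_gradients_with_ops=True)  # a standard compute_gradients() kwarg
```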
...@@ -21,7 +21,7 @@ from __future__ import print_function ...@@ -21,7 +21,7 @@ from __future__ import print_function
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from slim.models import model_deploy from deployment import model_deploy
slim = tf.contrib.slim slim = tf.contrib.slim
......
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Downloads and converts a particular dataset.
Usage:
```shell
$ python download_and_convert_data.py \
--dataset_name=mnist \
--dataset_dir=/tmp/mnist
$ python download_and_convert_data.py \
--dataset_name=cifar10 \
--dataset_dir=/tmp/cifar10
$ python download_and_convert_data.py \
--dataset_name=flowers \
--dataset_dir=/tmp/flowers
```
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from datasets import download_and_convert_cifar10
from datasets import download_and_convert_flowers
from datasets import download_and_convert_mnist
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string(
'dataset_name',
None,
'The name of the dataset to convert, one of "cifar10", "flowers", "mnist".')
tf.app.flags.DEFINE_string(
'dataset_dir',
None,
'The directory where the output TFRecords and temporary files are saved.')
def main(_):
if not FLAGS.dataset_name:
raise ValueError('You must supply the dataset name with --dataset_name')
if not FLAGS.dataset_dir:
raise ValueError('You must supply the dataset directory with --dataset_dir')
if FLAGS.dataset_name == 'cifar10':
download_and_convert_cifar10.run(FLAGS.dataset_dir)
elif FLAGS.dataset_name == 'flowers':
download_and_convert_flowers.run(FLAGS.dataset_dir)
elif FLAGS.dataset_name == 'mnist':
download_and_convert_mnist.run(FLAGS.dataset_dir)
else:
raise ValueError(
'dataset_name [%s] was not recognized.' % FLAGS.dataset_name)
if __name__ == '__main__':
tf.app.run()
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Generic evaluation script that trains a given model a specified dataset.""" """Generic evaluation script that evaluates a model using a given dataset."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -21,9 +21,9 @@ from __future__ import print_function ...@@ -21,9 +21,9 @@ from __future__ import print_function
import math import math
import tensorflow as tf import tensorflow as tf
from slim.datasets import dataset_factory from datasets import dataset_factory
from slim.models import model_factory from nets import nets_factory
from slim.models import preprocessing_factory from preprocessing import preprocessing_factory
slim = tf.contrib.slim slim = tf.contrib.slim
...@@ -42,11 +42,6 @@ tf.app.flags.DEFINE_string( ...@@ -42,11 +42,6 @@ tf.app.flags.DEFINE_string(
'The directory where the model was written to or an absolute path to a ' 'The directory where the model was written to or an absolute path to a '
'checkpoint file.') 'checkpoint file.')
tf.app.flags.DEFINE_bool(
'restore_global_step', True,
'Whether or not to restore the global step. When evaluating a model '
'checkpoint containing ONLY weights, set this flag to `False`.')
tf.app.flags.DEFINE_string( tf.app.flags.DEFINE_string(
'eval_dir', '/tmp/tfmodel/', 'Directory where the results are saved to.') 'eval_dir', '/tmp/tfmodel/', 'Directory where the results are saved to.')
...@@ -58,11 +53,10 @@ tf.app.flags.DEFINE_string( ...@@ -58,11 +53,10 @@ tf.app.flags.DEFINE_string(
'dataset_name', 'imagenet', 'The name of the dataset to load.') 'dataset_name', 'imagenet', 'The name of the dataset to load.')
tf.app.flags.DEFINE_string( tf.app.flags.DEFINE_string(
'dataset_split_name', 'train', 'The name of the train/test split.') 'dataset_split_name', 'test', 'The name of the train/test split.')
tf.app.flags.DEFINE_string( tf.app.flags.DEFINE_string(
'dataset_dir', None, 'The directory where the dataset files are stored.') 'dataset_dir', None, 'The directory where the dataset files are stored.')
tf.app.flags.MarkFlagAsRequired('dataset_dir')
tf.app.flags.DEFINE_integer( tf.app.flags.DEFINE_integer(
'labels_offset', 0, 'labels_offset', 0,
...@@ -82,10 +76,17 @@ tf.app.flags.DEFINE_float( ...@@ -82,10 +76,17 @@ tf.app.flags.DEFINE_float(
'The decay to use for the moving average.' 'The decay to use for the moving average.'
'If left as None, then moving averages are not used.') 'If left as None, then moving averages are not used.')
tf.app.flags.DEFINE_integer(
'eval_image_size', None, 'Eval image size')
FLAGS = tf.app.flags.FLAGS FLAGS = tf.app.flags.FLAGS
def main(_): def main(_):
if not FLAGS.dataset_dir:
raise ValueError('You must supply the dataset directory with --dataset_dir')
tf.logging.set_verbosity(tf.logging.INFO)
with tf.Graph().as_default(): with tf.Graph().as_default():
tf_global_step = slim.get_or_create_global_step() tf_global_step = slim.get_or_create_global_step()
...@@ -98,7 +99,7 @@ def main(_): ...@@ -98,7 +99,7 @@ def main(_):
#################### ####################
# Select the model # # Select the model #
#################### ####################
model_fn = model_factory.get_model( network_fn = nets_factory.get_network_fn(
FLAGS.model_name, FLAGS.model_name,
num_classes=(dataset.num_classes - FLAGS.labels_offset), num_classes=(dataset.num_classes - FLAGS.labels_offset),
is_training=False) is_training=False)
...@@ -122,9 +123,9 @@ def main(_): ...@@ -122,9 +123,9 @@ def main(_):
preprocessing_name, preprocessing_name,
is_training=False) is_training=False)
image = image_preprocessing_fn(image, eval_image_size = FLAGS.eval_image_size or network_fn.default_image_size
height=model_fn.default_image_size,
width=model_fn.default_image_size) image = image_preprocessing_fn(image, eval_image_size, eval_image_size)
images, labels = tf.train.batch( images, labels = tf.train.batch(
[image, label], [image, label],
...@@ -135,19 +136,16 @@ def main(_): ...@@ -135,19 +136,16 @@ def main(_):
#################### ####################
# Define the model # # Define the model #
#################### ####################
logits, _ = model_fn(images) logits, _ = network_fn(images)
if FLAGS.moving_average_decay: if FLAGS.moving_average_decay:
variable_averages = tf.train.ExponentialMovingAverage( variable_averages = tf.train.ExponentialMovingAverage(
FLAGS.moving_average_decay, tf_global_step) FLAGS.moving_average_decay, tf_global_step)
variables_to_restore = variable_averages.variables_to_restore( variables_to_restore = variable_averages.variables_to_restore(
slim.get_model_variables()) slim.get_model_variables())
variables_to_restore[tf_global_step.op.name] = tf_global_step
if FLAGS.restore_global_step:
variables_to_restore[tf_global_step.op.name] = tf_global_step
else: else:
exclude = None if FLAGS.restore_global_step else ['global_step'] variables_to_restore = slim.get_variables_to_restore()
variables_to_restore = slim.get_variables_to_restore(exclude=exclude)
predictions = tf.argmax(logits, 1) predictions = tf.argmax(logits, 1)
labels = tf.squeeze(labels) labels = tf.squeeze(labels)
...@@ -181,8 +179,8 @@ def main(_): ...@@ -181,8 +179,8 @@ def main(_):
tf.logging.info('Evaluating %s' % checkpoint_path) tf.logging.info('Evaluating %s' % checkpoint_path)
slim.evaluation.evaluate_once( slim.evaluation.evaluate_once(
FLAGS.master, master=FLAGS.master,
checkpoint_path, checkpoint_path=checkpoint_path,
logdir=FLAGS.eval_dir, logdir=FLAGS.eval_dir,
num_evals=num_batches, num_evals=num_batches,
eval_op=names_to_updates.values(), eval_op=names_to_updates.values(),
......
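A condensed sketch of the renamed pieces the evaluation script now relies on (model and preprocessing names are illustrative, and `get_preprocessing` is the assumed accessor on `preprocessing_factory`; the network and preprocessing call signatures follow the diff above):

```python
# Hedged sketch of the nets_factory / preprocessing_factory pairing used above.
import tensorflow as tf
from nets import nets_factory
from preprocessing import preprocessing_factory

network_fn = nets_factory.get_network_fn(
    'inception_v3', num_classes=1001, is_training=False)
image_preprocessing_fn = preprocessing_factory.get_preprocessing(
    'inception_v3', is_training=False)

raw_image = tf.placeholder(tf.uint8, [480, 640, 3])          # illustrative input
eval_image_size = network_fn.default_image_size               # 299 for inception_v3
image = image_preprocessing_fn(raw_image, eval_image_size, eval_image_size)
logits, _ = network_fn(tf.expand_dims(image, 0))
```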
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains a factory for building various models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.contrib.slim import nets
from slim.nets import lenet
slim = tf.contrib.slim
def get_model(name, num_classes, weight_decay=0.0, is_training=False):
"""Returns a model_fn such as `logits, end_points = model_fn(images)`.
Args:
name: The name of the model.
num_classes: The number of classes to use for classification.
weight_decay: The l2 coefficient for the model weights.
is_training: `True` if the model is being used for training and `False`
otherwise.
Returns:
model_fn: A function that applies the model to a batch of images. It has
the following signature:
logits, end_points = model_fn(images)
Raises:
ValueError: If model `name` is not recognized.
"""
if name == 'inception_v1':
default_image_size = nets.inception.inception_v1.default_image_size
def func(images):
with slim.arg_scope(nets.inception.inception_v1_arg_scope(
weight_decay=weight_decay)):
return nets.inception.inception_v1(images,
num_classes,
is_training=is_training)
model_fn = func
elif name == 'inception_v2':
default_image_size = nets.inception.inception_v2.default_image_size
def func(images):
with slim.arg_scope(nets.inception.inception_v2_arg_scope(
weight_decay=weight_decay)):
return nets.inception.inception_v2(images,
num_classes=num_classes,
is_training=is_training)
model_fn = func
elif name == 'inception_v3':
default_image_size = nets.inception.inception_v3.default_image_size
def func(images):
with slim.arg_scope(nets.inception.inception_v3_arg_scope(
weight_decay=weight_decay)):
return nets.inception.inception_v3(images,
num_classes=num_classes,
is_training=is_training)
model_fn = func
elif name == 'lenet':
default_image_size = lenet.lenet.default_image_size
def func(images):
with slim.arg_scope(lenet.lenet_arg_scope(weight_decay=weight_decay)):
return lenet.lenet(images,
num_classes=num_classes,
is_training=is_training)
model_fn = func
elif name == 'resnet_v1_50':
default_image_size = nets.resnet_v1.resnet_v1.default_image_size
def func(images):
with slim.arg_scope(nets.resnet_v1.resnet_arg_scope(
is_training, weight_decay=weight_decay)):
net, end_points = nets.resnet_v1.resnet_v1_50(
images, num_classes=num_classes)
net = tf.squeeze(net, squeeze_dims=[1, 2])
return net, end_points
model_fn = func
elif name == 'resnet_v1_101':
default_image_size = nets.resnet_v1.resnet_v1.default_image_size
def func(images):
with slim.arg_scope(nets.resnet_v1.resnet_arg_scope(
is_training, weight_decay=weight_decay)):
net, end_points = nets.resnet_v1.resnet_v1_101(
images, num_classes=num_classes)
net = tf.squeeze(net, squeeze_dims=[1, 2])
return net, end_points
model_fn = func
elif name == 'resnet_v1_152':
default_image_size = nets.resnet_v1.resnet_v1.default_image_size
def func(images):
with slim.arg_scope(nets.resnet_v1.resnet_arg_scope(
is_training, weight_decay=weight_decay)):
net, end_points = nets.resnet_v1.resnet_v1_152(
images, num_classes=num_classes)
net = tf.squeeze(net, squeeze_dims=[1, 2])
return net, end_points
model_fn = func
elif name == 'vgg_a':
default_image_size = nets.vgg.vgg_a.default_image_size
def func(images):
with slim.arg_scope(nets.vgg.vgg_arg_scope(weight_decay)):
return nets.vgg.vgg_a(images,
num_classes=num_classes,
is_training=is_training)
model_fn = func
elif name == 'vgg_16':
default_image_size = nets.vgg.vgg_16.default_image_size
def func(images):
with slim.arg_scope(nets.vgg.vgg_arg_scope(weight_decay)):
return nets.vgg.vgg_16(images,
num_classes=num_classes,
is_training=is_training)
model_fn = func
elif name == 'vgg_19':
default_image_size = nets.vgg.vgg_19.default_image_size
def func(images):
with slim.arg_scope(nets.vgg.vgg_arg_scope(weight_decay)):
return nets.vgg.vgg_19(images,
num_classes=num_classes,
is_training=is_training)
model_fn = func
else:
raise ValueError('Model name [%s] was not recognized' % name)
model_fn.default_image_size = default_image_size
return model_fn
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides utilities to preprocess images for the ResNet networks."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.contrib.slim import nets
from tensorflow.python.ops import control_flow_ops
slim = tf.contrib.slim
_R_MEAN = 123.68
_G_MEAN = 116.78
_B_MEAN = 103.94
_CROP_HEIGHT = nets.resnet_v1.resnet_v1.default_image_size
_CROP_WIDTH = nets.resnet_v1.resnet_v1.default_image_size
_RESIZE_SIDE = 256
def _mean_image_subtraction(image, means):
"""Subtracts the given means from each image channel.
For example:
means = [123.68, 116.779, 103.939]
image = _mean_image_subtraction(image, means)
Note that the rank of `image` must be known.
Args:
image: a tensor of size [height, width, C].
means: a C-vector of values to subtract from each channel.
Returns:
the centered image.
Raises:
ValueError: If the rank of `image` is unknown, if `image` has a rank other
than three or if the number of channels in `image` doesn't match the
number of values in `means`.
"""
if image.get_shape().ndims != 3:
raise ValueError('Input must be of size [height, width, C>0]')
num_channels = image.get_shape().as_list()[-1]
if len(means) != num_channels:
raise ValueError('len(means) must match the number of channels')
channels = tf.split(2, num_channels, image)
for i in range(num_channels):
channels[i] -= means[i]
return tf.concat(2, channels)
def _smallest_size_at_least(height, width, smallest_side):
"""Computes new shape with the smallest side equal to `smallest_side`.
Computes new shape with the smallest side equal to `smallest_side` while
preserving the original aspect ratio.
Args:
height: an int32 scalar tensor indicating the current height.
width: an int32 scalar tensor indicating the current width.
smallest_side: a python integer indicating the smallest side of the new
shape.
Returns:
new_height: an int32 scalar tensor indicating the new height.
new_width: an int32 scalar tensor indicating the new width.
"""
height = tf.to_float(height)
width = tf.to_float(width)
smallest_side = float(smallest_side)
scale = tf.cond(tf.greater(height, width),
lambda: smallest_side / width,
lambda: smallest_side / height)
new_height = tf.to_int32(height * scale)
new_width = tf.to_int32(width * scale)
return new_height, new_width
def _aspect_preserving_resize(image, smallest_side):
"""Resize images preserving the original aspect ratio.
Args:
image: a 3-D image tensor.
smallest_side: a python integer indicating the size of the smallest side
after resize.
Returns:
resized_image: a 3-D tensor containing the resized image.
"""
shape = tf.shape(image)
height = shape[0]
width = shape[1]
new_height, new_width = _smallest_size_at_least(height, width, smallest_side)
image = tf.expand_dims(image, 0)
resized_image = tf.image.resize_bilinear(image, [new_height, new_width],
align_corners=False)
resized_image = tf.squeeze(resized_image)
resized_image.set_shape([None, None, 3])
return resized_image
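Worked example of the resize rule above: the smaller side is scaled to `smallest_side` and the other side keeps the aspect ratio (a sketch assuming the module above is in scope; the input size is made up):

```python
# E.g. a 600x400 (height x width) input with smallest_side=256 becomes 384x256,
# since 400 is the smaller side and the scale factor is 256/400 = 0.64.
import tensorflow as tf

image = tf.zeros([600, 400, 3], dtype=tf.uint8)
resized = _aspect_preserving_resize(image, 256)  # evaluates to shape (384, 256, 3)
```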
def _crop(image, offset_height, offset_width, crop_height, crop_width):
"""Crops the given image using the provided offsets and sizes.
Note that the method doesn't assume we know the input image size but it does
assume we know the input image rank.
Args:
image: an image of shape [height, width, channels].
offset_height: a scalar tensor indicating the height offset.
offset_width: a scalar tensor indicating the width offset.
crop_height: the height of the cropped image.
crop_width: the width of the cropped image.
Returns:
the cropped (and resized) image.
Raises:
InvalidArgumentError: if the rank is not 3 or if the image dimensions are
less than the crop size.
"""
original_shape = tf.shape(image)
rank_assertion = tf.Assert(
tf.equal(tf.rank(image), 3),
['Rank of image must be equal to 3.'])
cropped_shape = control_flow_ops.with_dependencies(
[rank_assertion],
tf.pack([crop_height, crop_width, original_shape[2]]))
size_assertion = tf.Assert(
tf.logical_and(
tf.greater_equal(original_shape[0], crop_height),
tf.greater_equal(original_shape[1], crop_width)),
['Crop size greater than the image size.'])
offsets = tf.to_int32(tf.pack([offset_height, offset_width, 0]))
# Use tf.slice instead of crop_to_bounding box as it accepts tensors to
# define the crop size.
image = control_flow_ops.with_dependencies(
[size_assertion],
tf.slice(image, offsets, cropped_shape))
return tf.reshape(image, cropped_shape)
def _central_crop(image_list, crop_height, crop_width):
"""Performs central crops of the given image list.
Args:
image_list: a list of image tensors of the same dimension but possibly
varying channel.
crop_height: the height of the image following the crop.
crop_width: the width of the image following the crop.
Returns:
the list of cropped images.
"""
outputs = []
for image in image_list:
image_height = tf.shape(image)[0]
image_width = tf.shape(image)[1]
offset_height = (image_height - crop_height) / 2
offset_width = (image_width - crop_width) / 2
outputs.append(_crop(image, offset_height, offset_width,
crop_height, crop_width))
return outputs
def preprocess_image(image,
height=_CROP_HEIGHT,
width=_CROP_WIDTH,
is_training=False, # pylint: disable=unused-argument
resize_side=_RESIZE_SIDE):
image = _aspect_preserving_resize(image, resize_side)
image = _central_crop([image], height, width)[0]
image.set_shape([height, width, 3])
image = tf.to_float(image)
image = _mean_image_subtraction(image, [_R_MEAN, _G_MEAN, _B_MEAN])
return image
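A minimal sketch of running the eval-time pipeline end to end (the JPEG path is a placeholder and the module above is assumed importable):

```python
# Illustrative only: decode one JPEG and apply the resize/crop/mean-subtract
# pipeline defined above. The default crop size comes from resnet_v1 (224x224).
import tensorflow as tf

image_raw = tf.read_file('/tmp/example.jpg')      # placeholder path
image = tf.image.decode_jpeg(image_raw, channels=3)
processed = preprocess_image(image, is_training=False)

with tf.Session() as sess:
    print(sess.run(processed).shape)              # (224, 224, 3)
```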