"sgl-kernel/python/vscode:/vscode.git/clone" did not exist on "18efb5e8e0ed467f6dc42680d88787f5ed6c074e"
Commit 02732868 authored by shivaniag, committed by Neal Wu
Browse files

Dataset input pipeline for imagenet (#2466)

parent 33ca217f
...@@ -22,7 +22,6 @@ import os ...@@ -22,7 +22,6 @@ import os
import tensorflow as tf import tensorflow as tf
import imagenet
import resnet_model import resnet_model
import vgg_preprocessing import vgg_preprocessing
...@@ -49,11 +48,12 @@ parser.add_argument( ...@@ -49,11 +48,12 @@ parser.add_argument(
help='The number of training steps to run between evaluations.') help='The number of training steps to run between evaluations.')
parser.add_argument( parser.add_argument(
'--train_batch_size', type=int, default=32, help='Batch size for training.') '--batch_size', type=int, default=32,
help='Batch size for training and evaluation.')
parser.add_argument( parser.add_argument(
'--eval_batch_size', type=int, default=100, '--map_threads', type=int, default=5,
help='Batch size for evaluation.') help='The number of threads for dataset.map.')
parser.add_argument( parser.add_argument(
'--first_cycle_steps', type=int, default=None, '--first_cycle_steps', type=int, default=None,
...@@ -61,51 +61,104 @@ parser.add_argument( ...@@ -61,51 +61,104 @@ parser.add_argument(
'you have stopped partway through a training cycle.') 'you have stopped partway through a training cycle.')
FLAGS = parser.parse_args() FLAGS = parser.parse_args()
_EVAL_STEPS = 50000 // FLAGS.eval_batch_size
# Scale the learning rate linearly with the batch size. When the batch size is # Scale the learning rate linearly with the batch size. When the batch size is
# 256, the learning rate should be 0.1. # 256, the learning rate should be 0.1.
_INITIAL_LEARNING_RATE = 0.1 * FLAGS.train_batch_size / 256 _INITIAL_LEARNING_RATE = 0.1 * FLAGS.batch_size / 256
_NUM_CHANNELS = 3
_LABEL_CLASSES = 1001
_MOMENTUM = 0.9 _MOMENTUM = 0.9
_WEIGHT_DECAY = 1e-4 _WEIGHT_DECAY = 1e-4
train_dataset = imagenet.get_split('train', FLAGS.data_dir) _NUM_IMAGES = {
eval_dataset = imagenet.get_split('validation', FLAGS.data_dir) 'train': 1281167,
'validation': 50000,
}
image_preprocessing_fn = vgg_preprocessing.preprocess_image image_preprocessing_fn = vgg_preprocessing.preprocess_image
network = resnet_model.resnet_v2( network = resnet_model.resnet_v2(
resnet_size=FLAGS.resnet_size, num_classes=train_dataset.num_classes) resnet_size=FLAGS.resnet_size, num_classes=_LABEL_CLASSES)
batches_per_epoch = train_dataset.num_samples / FLAGS.train_batch_size batches_per_epoch = _NUM_IMAGES['train'] / FLAGS.batch_size
def filenames(is_training):
  """Return the list of TFRecord shard paths for the requested split.

  Args:
    is_training: bool; True selects the 1024 training shards, False the
      128 validation shards.

  Returns:
    A list of file path strings under FLAGS.data_dir.
  """
  # `range` instead of the Python-2-only `xrange`: the comprehension
  # materializes a list either way, so behavior is identical on Python 2
  # and the code also runs on Python 3.
  if is_training:
    return [
        os.path.join(FLAGS.data_dir, 'train-%05d-of-01024' % i)
        for i in range(1024)]
  return [
      os.path.join(FLAGS.data_dir, 'validation-%05d-of-00128' % i)
      for i in range(128)]
def dataset_parser(value, is_training):
  """Parse one serialized ImageNet Example into (image, one_hot_label).

  Args:
    value: scalar string tensor holding a serialized tf.Example record.
    is_training: bool; selects train vs. eval preprocessing.

  Returns:
    Tuple of a preprocessed float32 image tensor and a one-hot label
    tensor with _LABEL_CLASSES classes.
  """
  # Fixed-length scalar features of the standard ImageNet TFRecord schema.
  keys_to_features = {
      'image/encoded':
          tf.FixedLenFeature((), tf.string, default_value=''),
      'image/format':
          tf.FixedLenFeature((), tf.string, default_value='jpeg'),
      'image/class/label':
          tf.FixedLenFeature([], dtype=tf.int64, default_value=-1),
      'image/class/text':
          tf.FixedLenFeature([], dtype=tf.string, default_value=''),
      'image/object/class/label':
          tf.VarLenFeature(dtype=tf.int64),
  }
  # Variable-length bounding-box corners share one spec; add them in a loop.
  for corner in ('xmin', 'ymin', 'xmax', 'ymax'):
    keys_to_features['image/object/bbox/' + corner] = tf.VarLenFeature(
        dtype=tf.float32)

  parsed = tf.parse_single_example(value, keys_to_features)

  # Decode the JPEG bytes and scale to float32 in [0, 1].
  encoded = tf.reshape(parsed['image/encoded'], shape=[])
  image = tf.image.decode_image(encoded, _NUM_CHANNELS)
  image = tf.image.convert_image_dtype(image, dtype=tf.float32)

  # VGG-style crop/flip (train) or central crop (eval) to the network's
  # expected spatial size.
  image = image_preprocessing_fn(
      image=image,
      output_height=network.default_image_size,
      output_width=network.default_image_size,
      is_training=is_training)

  label = tf.cast(
      tf.reshape(parsed['image/class/label'], shape=[]),
      dtype=tf.int32)

  return image, tf.one_hot(label, _LABEL_CLASSES)
def input_fn(is_training):
  """Input function which provides a single batch for train or eval."""
  # Start from the list of TFRecord shard filenames for this split.
  ds = tf.contrib.data.Dataset.from_tensor_slices(filenames(is_training))
  if is_training:
    # File-level shuffle so successive epochs read shards in new orders.
    ds = ds.shuffle(buffer_size=1024)

  # Expand each filename into the stream of records it contains.
  ds = ds.flat_map(tf.contrib.data.TFRecordDataset)
  if is_training:
    ds = ds.repeat()

  # Parse and preprocess records in parallel.
  parse = lambda value: dataset_parser(value, is_training)
  ds = ds.map(parse,
              num_threads=FLAGS.map_threads,
              output_buffer_size=FLAGS.batch_size)

  if is_training:
    # Record-level shuffle on top of the file-level one.
    ds = ds.shuffle(buffer_size=1250 + 2 * FLAGS.batch_size)

  iterator = ds.batch(FLAGS.batch_size).make_one_shot_iterator()
  images, labels = iterator.get_next()
  return images, labels
...@@ -204,8 +257,7 @@ def main(unused_argv): ...@@ -204,8 +257,7 @@ def main(unused_argv):
FLAGS.first_cycle_steps = None FLAGS.first_cycle_steps = None
print('Starting to evaluate.') print('Starting to evaluate.')
eval_results = resnet_classifier.evaluate( eval_results = resnet_classifier.evaluate(input_fn=lambda: input_fn(False))
input_fn=lambda: input_fn(False), steps=_EVAL_STEPS)
print(eval_results) print(eval_results)
......
...@@ -127,10 +127,10 @@ class BaseTest(tf.test.TestCase): ...@@ -127,10 +127,10 @@ class BaseTest(tf.test.TestCase):
def input_fn(self): def input_fn(self):
"""Provides random features and labels.""" """Provides random features and labels."""
features = tf.random_uniform([FLAGS.train_batch_size, 224, 224, 3]) features = tf.random_uniform([FLAGS.batch_size, 224, 224, 3])
labels = tf.one_hot( labels = tf.one_hot(
tf.random_uniform( tf.random_uniform(
[FLAGS.train_batch_size], maxval=_LABEL_CLASSES - 1, [FLAGS.batch_size], maxval=_LABEL_CLASSES - 1,
dtype=tf.int32), dtype=tf.int32),
_LABEL_CLASSES) _LABEL_CLASSES)
...@@ -145,9 +145,9 @@ class BaseTest(tf.test.TestCase): ...@@ -145,9 +145,9 @@ class BaseTest(tf.test.TestCase):
predictions = spec.predictions predictions = spec.predictions
self.assertAllEqual(predictions['probabilities'].shape, self.assertAllEqual(predictions['probabilities'].shape,
(FLAGS.train_batch_size, _LABEL_CLASSES)) (FLAGS.batch_size, _LABEL_CLASSES))
self.assertEqual(predictions['probabilities'].dtype, tf.float32) self.assertEqual(predictions['probabilities'].dtype, tf.float32)
self.assertAllEqual(predictions['classes'].shape, (FLAGS.train_batch_size,)) self.assertAllEqual(predictions['classes'].shape, (FLAGS.batch_size,))
self.assertEqual(predictions['classes'].dtype, tf.int64) self.assertEqual(predictions['classes'].dtype, tf.int64)
if mode != tf.estimator.ModeKeys.PREDICT: if mode != tf.estimator.ModeKeys.PREDICT:
......
...@@ -333,7 +333,8 @@ def imagenet_resnet_v2_generator(block_fn, layers, num_classes, ...@@ -333,7 +333,8 @@ def imagenet_resnet_v2_generator(block_fn, layers, num_classes,
inputs=inputs, pool_size=7, strides=1, padding='VALID', inputs=inputs, pool_size=7, strides=1, padding='VALID',
data_format=data_format) data_format=data_format)
inputs = tf.identity(inputs, 'final_avg_pool') inputs = tf.identity(inputs, 'final_avg_pool')
inputs = tf.reshape(inputs, [inputs.get_shape()[0].value, -1]) inputs = tf.reshape(inputs,
[-1, 512 if block_fn is building_block else 2048])
inputs = tf.layers.dense(inputs=inputs, units=num_classes) inputs = tf.layers.dense(inputs=inputs, units=num_classes)
inputs = tf.identity(inputs, 'final_dense') inputs = tf.identity(inputs, 'final_dense')
return inputs return inputs
......
...@@ -34,9 +34,9 @@ from __future__ import print_function ...@@ -34,9 +34,9 @@ from __future__ import print_function
import tensorflow as tf import tensorflow as tf
_R_MEAN = 123.68 _R_MEAN = 123.68 / 255
_G_MEAN = 116.78 _G_MEAN = 116.78 / 255
_B_MEAN = 103.94 _B_MEAN = 103.94 / 255
_RESIZE_SIDE_MIN = 256 _RESIZE_SIDE_MIN = 256
_RESIZE_SIDE_MAX = 512 _RESIZE_SIDE_MAX = 512
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment