"...text-generation-inference.git" did not exist on "b95732180dc52be869e8c3e752a9c54608a6c7a5"
Commit 04235f9b authored by Toby Boyd, committed by GitHub

Merge pull request #1889 from mari-linhares/cifar10_experiment

Add an option to use Experiments and sync training with Cifar10_estimator
parents ed65b632 2f69dc64
@@ -23,8 +23,9 @@ tar xzf cifar-10-binary.tar.gz
 $ ls -R cifar-10-batches-bin
 cifar-10-batches-bin:
-batches.meta.txt  data_batch_1.bin  data_batch_2.bin  data_batch_3.bin
-data_batch_4.bin  data_batch_5.bin  readme.html  test_batch.bin
+batches.meta  data_batch_1  data_batch_2  data_batch_3
+data_batch_4  data_batch_5  readme.html  test_batch

 # Run the model on CPU only. After training, it runs the evaluation.
 $ python cifar10_main.py --data_dir=/prefix/to/downloaded/data/cifar-10-batches-bin \
...
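The listing above is the python (pickle) version of the dataset, so the raw batches must be converted to TFRecords before training. A hedged sketch of that step, assuming the conversion script added at the end of this commit is saved as `generate_cifar10_tfrecords.py` (its file name is not visible in this diff), with placeholder paths:

# Download the python version of CIFAR-10 and convert it to TFRecords.
$ curl -O https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
$ tar xzf cifar-10-python.tar.gz
$ python generate_cifar10_tfrecords.py --input_dir=cifar-10-batches-py \
    --output_dir=/tmp/cifar10_tfrecords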
@@ -16,10 +16,8 @@
 See http://www.cs.toronto.edu/~kriz/cifar.html.
 """
-import cPickle
 import os

-import numpy as np
 from six.moves import xrange  # pylint: disable=redefined-builtin
 import tensorflow as tf
@@ -28,6 +26,7 @@ HEIGHT = 32
 WIDTH = 32
 DEPTH = 3

 class Cifar10DataSet(object):
   """Cifar10 data set.
@@ -38,30 +37,53 @@ class Cifar10DataSet(object):
     self.data_dir = data_dir
     self.subset = subset
     self.use_distortion = use_distortion

   def get_filenames(self):
     if self.subset == 'train':
       return [
-          os.path.join(self.data_dir, 'data_batch_%d.bin' % i)
+          os.path.join(self.data_dir, 'data_batch_%d.tfrecords' % i)
           for i in xrange(1, 5)
       ]
     elif self.subset == 'validation':
-      return [os.path.join(self.data_dir, 'data_batch_5.bin')]
+      return [os.path.join(self.data_dir, 'data_batch_5.tfrecords')]
     elif self.subset == 'eval':
-      return [os.path.join(self.data_dir, 'test_batch.bin')]
+      return [os.path.join(self.data_dir, 'test_batch.tfrecords')]
     else:
       raise ValueError('Invalid data subset "%s"' % self.subset)

+  def parser(self, serialized_example):
+    """Parses a single tf.Example into image and label tensors."""
+    # Dimensions of the images in the CIFAR-10 dataset.
+    # See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the
+    # input format.
+    features = tf.parse_single_example(
+        serialized_example,
+        features={
+            'image': tf.FixedLenFeature([], tf.string),
+            'label': tf.FixedLenFeature([], tf.int64),
+        })
+    image = tf.decode_raw(features['image'], tf.uint8)
+    image.set_shape([DEPTH * HEIGHT * WIDTH])
+
+    # Reshape from [depth * height * width] to [depth, height, width].
+    image = tf.transpose(tf.reshape(image, [DEPTH, HEIGHT, WIDTH]), [1, 2, 0])
+    label = tf.cast(features['label'], tf.int32)
+
+    # Custom preprocessing.
+    image = self.preprocess(image)
+
+    return image, label
+
   def make_batch(self, batch_size):
     """Read the images and labels from 'filenames'."""
     filenames = self.get_filenames()
-    record_bytes = (32 * 32 * 3) + 1

     # Repeat infinitely.
-    dataset = tf.contrib.data.FixedLengthRecordDataset(filenames,
-                                                       record_bytes).repeat()
+    dataset = tf.contrib.data.TFRecordDataset(filenames).repeat()

     # Parse records.
     dataset = dataset.map(self.parser, num_threads=batch_size,
                           output_buffer_size=2 * batch_size)

     # Potentially shuffle records.
     if self.subset == 'train':
       min_queue_examples = int(
@@ -69,49 +91,13 @@ class Cifar10DataSet(object):
       # Ensure that the capacity is sufficiently large to provide good random
       # shuffling.
       dataset = dataset.shuffle(buffer_size=min_queue_examples + 3 * batch_size)

     # Batch it up.
     dataset = dataset.batch(batch_size)
     iterator = dataset.make_one_shot_iterator()
     image_batch, label_batch = iterator.get_next()

     return image_batch, label_batch

-  def parser(self, value):
-    """Parse a Cifar10 record from value.
-
-    Output images are in [height, width, depth] layout.
-    """
-    # Dimensions of the images in the CIFAR-10 dataset.
-    # See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the
-    # input format.
-    label_bytes = 1
-    image_bytes = HEIGHT * WIDTH * DEPTH
-    # Every record consists of a label followed by the image, with a
-    # fixed number of bytes for each.
-    record_bytes = label_bytes + image_bytes
-
-    # Convert from a string to a vector of uint8 that is record_bytes long.
-    record_as_bytes = tf.decode_raw(value, tf.uint8)
-    # The first bytes represent the label, which we convert from
-    # uint8->int32.
-    label = tf.cast(
-        tf.strided_slice(record_as_bytes, [0], [label_bytes]), tf.int32)
-    label.set_shape([1])
-
-    # The remaining bytes after the label represent the image, which
-    # we reshape from [depth * height * width] to [depth, height, width].
-    depth_major = tf.reshape(
-        tf.strided_slice(record_as_bytes, [label_bytes], [record_bytes]),
-        [3, 32, 32])
-    # Convert from [depth, height, width] to [height, width, depth].
-    # This puts data in a compatible layout with TF image preprocessing APIs.
-    image = tf.transpose(depth_major, [1, 2, 0])
-
-    # Do custom preprocessing here.
-    image = self.preprocess(image)
-
-    return image, label
-
   def preprocess(self, image):
     """Preprocess a single image in [height, width, depth] layout."""
...
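As a quick sanity check of the new TFRecord-based pipeline, a minimal sketch (the data directory is a placeholder, and it assumes the records have already been generated):

```python
# Sketch: pull one batch through the new input pipeline (TF 1.x).
import tensorflow as tf

from cifar10 import Cifar10DataSet

dataset = Cifar10DataSet(data_dir='/tmp/cifar10_tfrecords',
                         subset='eval', use_distortion=False)
images, labels = dataset.make_batch(batch_size=4)

with tf.Session() as sess:
  image_vals, label_vals = sess.run([images, labels])
  # parser() emits [height, width, depth] images, so a batch is NHWC.
  print(image_vals.shape)  # (4, 32, 32, 3)
  print(label_vals)        # four int32 class ids in [0, 9]
```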
@@ -33,12 +33,16 @@ import functools
 import operator
 import os

-import cifar10
-import cifar10_model
 import numpy as np
 from six.moves import xrange  # pylint: disable=redefined-builtin
 import tensorflow as tf
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import basic_session_run_hooks
+from tensorflow.python.training import session_run_hook
+from tensorflow.python.training import training_util
+
+import cifar10
+import cifar10_model

 tf.logging.set_verbosity(tf.logging.INFO)
@@ -50,7 +54,7 @@ tf.flags.DEFINE_string('data_dir', '',
 tf.flags.DEFINE_string('model_dir', '',
                        'The directory where the model will be stored.')

-tf.flags.DEFINE_boolean('is_cpu_ps', False,
+tf.flags.DEFINE_boolean('is_cpu_ps', True,
                         'If using CPU as the parameter server.')

 tf.flags.DEFINE_integer('num_gpus', 1,
@@ -58,7 +62,7 @@ tf.flags.DEFINE_integer('num_gpus', 1,
 tf.flags.DEFINE_integer('num_layers', 44, 'The number of layers of the model.')

-tf.flags.DEFINE_integer('train_steps', 10000,
+tf.flags.DEFINE_integer('train_steps', 80000,
                         'The number of steps to use for training.')

 tf.flags.DEFINE_integer('train_batch_size', 128, 'Batch size for training.')
@@ -67,11 +71,26 @@ tf.flags.DEFINE_integer('eval_batch_size', 100, 'Batch size for validation.')
 tf.flags.DEFINE_float('momentum', 0.9, 'Momentum for MomentumOptimizer.')

-tf.flags.DEFINE_float('weight_decay', 1e-4, 'Weight decay for convolutions.')
+tf.flags.DEFINE_float('weight_decay', 2e-4, 'Weight decay for convolutions.')

 tf.flags.DEFINE_boolean('use_distortion_for_training', True,
                         'If doing image distortion for training.')

+tf.flags.DEFINE_boolean('run_experiment', False,
+                        'If True, runs an Experiment; otherwise runs training '
+                        'and evaluation using the Estimator interface. '
+                        'Experiments perform training on several workers in '
+                        'parallel; in other words, Experiments know how to '
+                        'invoke train and eval in a sensible fashion for '
+                        'distributed training.')
+
+tf.flags.DEFINE_boolean('sync', False,
+                        'If true, and running in a distributed environment, '
+                        'runs in sync mode.')
+
+tf.flags.DEFINE_integer('num_workers', 1, 'Number of workers.')
+
 # Perf flags
 tf.flags.DEFINE_integer('num_intra_threads', 1,
                         """Number of threads to use for intra-op parallelism.
@@ -98,6 +117,68 @@ tf.flags.DEFINE_boolean('log_device_placement', False,
                         'Whether to log device placement.')

+
+class ExamplesPerSecondHook(session_run_hook.SessionRunHook):
+  """Hook to print out examples per second.
+
+  Total time is tracked and then divided by the total number of steps
+  to get the average step time, and batch_size is used to determine the
+  running average of examples per second. The examples per second for the
+  most recent interval is also logged.
+  """
+
+  def __init__(self, batch_size, every_n_steps=100, every_n_secs=None):
+    """Initializer for ExamplesPerSecondHook.
+
+    Args:
+      batch_size: Total batch size used to calculate examples/second from
+        global time.
+      every_n_steps: Log stats every n steps.
+      every_n_secs: Log stats every n seconds.
+    """
+    if (every_n_steps is None) == (every_n_secs is None):
+      raise ValueError('exactly one of every_n_steps'
+                       ' and every_n_secs should be provided.')
+    self._timer = basic_session_run_hooks.SecondOrStepTimer(
+        every_steps=every_n_steps, every_secs=every_n_secs)
+
+    self._step_train_time = 0
+    self._total_steps = 0
+    self._batch_size = batch_size
+
+  def begin(self):
+    self._global_step_tensor = training_util.get_global_step()
+    if self._global_step_tensor is None:
+      raise RuntimeError(
+          'Global step should be created to use ExamplesPerSecondHook.')
+
+  def before_run(self, run_context):  # pylint: disable=unused-argument
+    return basic_session_run_hooks.SessionRunArgs(self._global_step_tensor)
+
+  def after_run(self, run_context, run_values):
+    _ = run_context
+
+    global_step = run_values.results
+    if self._timer.should_trigger_for_step(global_step):
+      elapsed_time, elapsed_steps = self._timer.update_last_triggered_step(
+          global_step)
+      if elapsed_time is not None:
+        steps_per_sec = elapsed_steps / elapsed_time
+        self._step_train_time += elapsed_time
+        self._total_steps += elapsed_steps
+
+        average_examples_per_sec = self._batch_size * (
+            self._total_steps / self._step_train_time)
+        current_examples_per_sec = steps_per_sec * self._batch_size
+        # Average examples/sec followed by current examples/sec
+        logging.info('%s: %g (%g), step = %g', 'Average examples/sec',
+                     average_examples_per_sec, current_examples_per_sec,
+                     self._total_steps)
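The hook logs two rates: a running average, `batch_size * total_steps / total_time`, and the rate over the most recent interval, `batch_size * steps_per_sec`. For example, with `batch_size=128`, a 10-step interval that takes 0.8 s logs a current rate of 128 * 10 / 0.8 = 1600 examples/sec. A minimal sketch of attaching the hook to an Estimator built from this file (the model directory is a placeholder):

```python
# Sketch: report throughput every 10 steps during Estimator training.
examples_sec_hook = ExamplesPerSecondHook(batch_size=128, every_n_steps=10)
classifier = tf.estimator.Estimator(model_fn=_resnet_model_fn,
                                    model_dir='/tmp/cifar10_model')
classifier.train(
    input_fn=functools.partial(input_fn, subset='train', num_shards=1),
    steps=100,
    hooks=[examples_sec_hook])
```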
 class GpuParamServerDeviceSetter(object):
   """Used with tf.device() to place variables on the least loaded GPU.
@@ -109,6 +190,7 @@ class GpuParamServerDeviceSetter(object):

   def __init__(self, worker_device, ps_devices):
     """Initializer for GpuParamServerDeviceSetter.
+
     Args:
       worker_device: the device to use for computation Ops.
       ps_devices: a list of devices to use for Variable Ops. Each variable is
@@ -123,7 +205,7 @@ class GpuParamServerDeviceSetter(object):
       return op.device
     if op.type not in ['Variable', 'VariableV2', 'VarHandleOp']:
       return self.worker_device

     # Gets the least loaded ps_device
     device_index, _ = min(enumerate(self.ps_sizes), key=operator.itemgetter(1))
     device_name = self.ps_devices[device_index]
@@ -132,6 +214,7 @@ class GpuParamServerDeviceSetter(object):
     return device_name


 def _create_device_setter(is_cpu_ps, worker, num_gpus):
   """Create device setter object."""
   if is_cpu_ps:
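The setter greedily assigns each new variable to the parameter-server device with the smallest accumulated size, while compute ops stay on the worker. A sketch of the assumed usage with `tf.device` (device strings and expected placements are illustrative):

```python
# Sketch: spread variables across two GPUs by accumulated byte count.
setter = GpuParamServerDeviceSetter('/gpu:0', ['/gpu:0', '/gpu:1'])
with tf.device(setter):
  w1 = tf.get_variable('w1', shape=[1024, 1024])  # least loaded: /gpu:0
  w2 = tf.get_variable('w2', shape=[512, 512])    # expected on /gpu:1 now
  y = tf.matmul(w1, w1)  # a compute op: placed on the worker device
```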
@@ -204,7 +287,7 @@ def _resnet_model_fn(features, labels, mode):
     ps_device = '/cpu:0' if is_cpu_ps else '/gpu:0'
     with tf.device(ps_device):
       with tf.name_scope('gradient_averaging'):
-        loss = tf.reduce_mean(tower_losses)
+        loss = tf.reduce_mean(tower_losses, name='loss')
         for zipped_gradvars in zip(*tower_gradvars):
           # Averaging one var's gradients computed from multiple towers
           var = zipped_gradvars[0][1]
@@ -220,7 +303,7 @@ def _resnet_model_fn(features, labels, mode):
     # https://github.com/ppwwyyxx/tensorpack/blob/master/examples/ResNet/cifar10-resnet.py#L155
     # users could apply other scheduling.
     num_batches_per_epoch = cifar10.Cifar10DataSet.num_examples_per_epoch(
-        'train') // FLAGS.train_batch_size
+        'train') // (FLAGS.train_batch_size * FLAGS.num_workers)
     boundaries = [
         num_batches_per_epoch * x
         for x in np.array([82, 123, 300], dtype=np.int64)
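With synchronous replicas, each global step consumes `train_batch_size * num_workers` examples, so the epoch length in steps shrinks accordingly. A worked example of the boundary arithmetic, assuming a 45,000-example train split (an assumption; the actual value comes from `num_examples_per_epoch`, which is not shown in this diff):

```python
# Illustration of the decay-boundary arithmetic; 45000 is assumed.
num_examples = 45000
train_batch_size, num_workers = 128, 2
num_batches_per_epoch = num_examples // (train_batch_size * num_workers)  # 175
boundaries = [num_batches_per_epoch * x for x in [82, 123, 300]]
print(boundaries)  # [14350, 21525, 52500]: global steps where the LR decays
```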
@@ -234,6 +317,14 @@ def _resnet_model_fn(features, labels, mode):
     optimizer = tf.train.MomentumOptimizer(
         learning_rate=learning_rate, momentum=momentum)

+    chief_hooks = []
+    if FLAGS.sync:
+      optimizer = tf.train.SyncReplicasOptimizer(
+          optimizer,
+          replicas_to_aggregate=FLAGS.num_workers)
+      sync_replicas_hook = optimizer.make_session_run_hook(True)
+      chief_hooks.append(sync_replicas_hook)
+
     # Create single grouped train op
     train_op = [
         optimizer.apply_gradients(
@@ -258,6 +349,7 @@ def _resnet_model_fn(features, labels, mode):
       predictions=predictions,
       loss=loss,
       train_op=train_op,
+      training_chief_hooks=chief_hooks,
       eval_metric_ops=metrics)
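Note that `make_session_run_hook(True)` hardcodes the chief. In a generic multi-worker job, each replica wraps its optimizer the same way, but only the chief's hook initializes the aggregation state; a hedged sketch (`task_index` is a hypothetical flag, not defined in this file):

```python
# Sketch: chief vs. non-chief wiring for SyncReplicasOptimizer.
is_chief = FLAGS.task_index == 0  # hypothetical flag, for illustration only
optimizer = tf.train.SyncReplicasOptimizer(
    tf.train.MomentumOptimizer(learning_rate=0.1, momentum=0.9),
    replicas_to_aggregate=FLAGS.num_workers)
sync_replicas_hook = optimizer.make_session_run_hook(is_chief)
```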
@@ -312,7 +404,8 @@ def input_fn(subset, num_shards):
   elif subset == 'validate' or subset == 'eval':
     batch_size = FLAGS.eval_batch_size
   else:
-    raise ValueError('Subset must be one of \'train\', \'validate\' and \'eval\'')
+    raise ValueError('Subset must be one of \'train\''
+                     ', \'validate\' and \'eval\'')
   with tf.device('/cpu:0'):
     use_distortion = subset == 'train' and FLAGS.use_distortion_for_training
     dataset = cifar10.Cifar10DataSet(FLAGS.data_dir, subset, use_distortion)
@@ -338,6 +431,33 @@ def input_fn(subset, num_shards):
   return feature_shards, label_shards


+# create experiment
+def get_experiment_fn(train_input_fn, eval_input_fn, train_steps, eval_steps,
+                      train_hooks):
+  """Returns an Experiment function.
+
+  Experiments perform training on several workers in parallel;
+  in other words, Experiments know how to invoke train and eval in a
+  sensible fashion for distributed training.
+  """
+
+  def _experiment_fn(run_config, hparams):
+    """Returns an Experiment."""
+    del hparams  # unused arg
+
+    # create estimator
+    classifier = tf.estimator.Estimator(model_fn=_resnet_model_fn,
+                                        config=run_config)
+    experiment = tf.contrib.learn.Experiment(
+        classifier,
+        train_input_fn=train_input_fn,
+        eval_input_fn=eval_input_fn,
+        train_steps=train_steps,
+        eval_steps=eval_steps)
+
+    # adding hooks to estimator on training mode
+    experiment.extend_train_hooks(train_hooks)
+    return experiment
+
+  return _experiment_fn
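`learn_runner.run` builds its view of the cluster from the `TF_CONFIG` environment variable, set per process before launch. A sketch with placeholder hosts and ports:

```python
# Sketch: TF_CONFIG for one worker process (hosts/ports are placeholders).
import json
import os

os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'ps': ['ps0.example.com:2222'],
        'worker': ['worker0.example.com:2222', 'worker1.example.com:2222'],
    },
    'task': {'type': 'worker', 'index': 0},
})
```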
 def main(unused_argv):
   # The env variable is on deprecation path, default is set to off.
   os.environ['TF_SYNC_ON_FINISH'] = '0'
@@ -359,36 +479,60 @@ def main(unused_argv):
   if num_eval_examples % FLAGS.eval_batch_size != 0:
     raise ValueError('validation set size must be multiple of eval_batch_size')

-  config = tf.estimator.RunConfig()
+  train_input_fn = functools.partial(input_fn, subset='train',
+                                     num_shards=FLAGS.num_gpus)
+  eval_input_fn = functools.partial(input_fn, subset='eval',
+                                    num_shards=FLAGS.num_gpus)
+
+  train_steps = FLAGS.train_steps
+  eval_steps = num_eval_examples // FLAGS.eval_batch_size
+
+  # Session configuration.
   sess_config = tf.ConfigProto()
   sess_config.allow_soft_placement = True
   sess_config.log_device_placement = FLAGS.log_device_placement
   sess_config.intra_op_parallelism_threads = FLAGS.num_intra_threads
   sess_config.inter_op_parallelism_threads = FLAGS.num_inter_threads
   sess_config.gpu_options.force_gpu_compatible = FLAGS.force_gpu_compatible
-  config = config.replace(session_config=sess_config)

-  classifier = tf.estimator.Estimator(
-      model_fn=_resnet_model_fn, model_dir=FLAGS.model_dir, config=config)
-
-  tensors_to_log = {'learning_rate': 'learning_rate'}
+  # Hooks that add extra logging that is useful to see the loss more often in
+  # the console as well as examples per second.
+  tensors_to_log = {'learning_rate': 'learning_rate',
+                    'loss': 'gradient_averaging/loss'}

   logging_hook = tf.train.LoggingTensorHook(
       tensors=tensors_to_log, every_n_iter=100)

-  print('Starting to train...')
-  classifier.train(
-      input_fn=functools.partial(
-          input_fn, subset='train', num_shards=FLAGS.num_gpus),
-      steps=FLAGS.train_steps,
-      hooks=[logging_hook])
-
-  print('Starting to evaluate...')
-  eval_results = classifier.evaluate(
-      input_fn=functools.partial(
-          input_fn, subset='eval', num_shards=FLAGS.num_gpus),
-      steps=num_eval_examples // FLAGS.eval_batch_size)
-  print(eval_results)
+  examples_sec_hook = ExamplesPerSecondHook(
+      FLAGS.train_batch_size, every_n_steps=10)
+
+  hooks = [logging_hook, examples_sec_hook]
+
+  if FLAGS.run_experiment:
+    config = tf.contrib.learn.RunConfig(model_dir=FLAGS.model_dir)
+    config = config.replace(session_config=sess_config)
+
+    tf.contrib.learn.learn_runner.run(
+        get_experiment_fn(train_input_fn, eval_input_fn,
+                          train_steps, eval_steps,
+                          hooks), run_config=config)
+  else:
+    config = tf.estimator.RunConfig()
+    config = config.replace(session_config=sess_config)
+
+    classifier = tf.estimator.Estimator(
+        model_fn=_resnet_model_fn, model_dir=FLAGS.model_dir, config=config)
+
+    print('Starting to train...')
+    classifier.train(input_fn=train_input_fn,
+                     steps=train_steps,
+                     hooks=hooks)
+
+    print('Starting to evaluate...')
+    eval_results = classifier.evaluate(
+        input_fn=eval_input_fn,
+        steps=eval_steps)
+    print(eval_results)
 if __name__ == '__main__':
   tf.app.run()
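With the new flags in place, both code paths can be exercised from the command line (data and model paths are placeholders):

# Default path: train and evaluate through the Estimator interface.
$ python cifar10_main.py --data_dir=/tmp/cifar10_tfrecords \
    --model_dir=/tmp/cifar10_model

# Experiment path with synchronous replicas.
$ python cifar10_main.py --data_dir=/tmp/cifar10_tfrecords \
    --model_dir=/tmp/cifar10_model \
    --run_experiment=True --sync=True --num_workers=2

The commit also adds the new script below, which converts the pickled python version of CIFAR-10 into the TFRecord files that cifar10.py now expects.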
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Read CIFAR-10 data from pickled numpy arrays and write TFExamples.
Generates TFRecord files from the python version of the CIFAR-10 dataset
downloaded from https://www.cs.toronto.edu/~kriz/cifar.html.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import cPickle
import os
import tensorflow as tf
FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_string('input_dir', '',
'Directory where CIFAR10 data is located.')
tf.flags.DEFINE_string('output_dir', '',
'Directory where TFRecords will be saved.'
'The TFRecords will have the same name as'
' the CIFAR10 inputs + .tfrecords.')
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[str(value)]))
def _get_file_names():
"""Returns the file names expected to exist in the input_dir."""
file_names = ['data_batch_%d' % i for i in xrange(1, 6)]
file_names.append('test_batch')
return file_names
def read_pickle_from_file(filename):
with open(filename, 'r') as f:
data_dict = cPickle.load(f)
return data_dict
def main(argv):
del argv # Unused.
file_names = _get_file_names()
for file_name in file_names:
input_file = os.path.join(FLAGS.input_dir, file_name)
output_file = os.path.join(FLAGS.output_dir, file_name + '.tfrecords')
print('Generating %s' % output_file)
record_writer = tf.python_io.TFRecordWriter(output_file)
data_dict = read_pickle_from_file(input_file)
data = data_dict['data']
labels = data_dict['labels']
num_entries_in_batch = len(labels)
for i in range(num_entries_in_batch):
example = tf.train.Example(
features=tf.train.Features(feature={
'image': _bytes_feature(data[i].tobytes()),
'label': _int64_feature(labels[i])
}))
record_writer.write(example.SerializeToString())
record_writer.close()
print('Done!')
if __name__ == '__main__':
tf.app.run(main)
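A quick hedged check that a generated file round-trips as expected (the path is a placeholder):

```python
# Sketch: read back the first record and confirm the expected layout.
import tensorflow as tf

path = '/tmp/cifar10_tfrecords/data_batch_1.tfrecords'
for serialized in tf.python_io.tf_record_iterator(path):
  example = tf.train.Example()
  example.ParseFromString(serialized)
  image_bytes = example.features.feature['image'].bytes_list.value[0]
  label = example.features.feature['label'].int64_list.value[0]
  print(len(image_bytes), label)  # 3072 (= 32 * 32 * 3) and a class id in [0, 9]
  break
```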