Commit 2f8481da authored by Shining Sun's avatar Shining Sun
Browse files

Synth data works

parent db80d57a
...@@ -72,28 +72,60 @@ def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch, batc ...@@ -72,28 +72,60 @@ def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch, batc
def parse_record_keras(raw_record, is_training, dtype): def parse_record_keras(raw_record, is_training, dtype):
"""Adjust the shape of label.""" """Adjust the shape of label."""
image_buffer, label, bbox = imagenet_main._parse_example_proto(raw_record)
image = imagenet_preprocessing.preprocess_image(
image_buffer=image_buffer,
bbox=bbox,
output_height=imagenet_main.DEFAULT_IMAGE_SIZE,
output_width=imagenet_main.DEFAULT_IMAGE_SIZE,
num_channels=imagenet_main.NUM_CHANNELS,
is_training=is_training)
image = tf.cast(image, dtype)
label = tf.sparse_to_dense(label, (imagenet_main.NUM_CLASSES,), 1)
"""
image, label = imagenet_main.parse_record(raw_record, is_training, dtype) image, label = imagenet_main.parse_record(raw_record, is_training, dtype)
# Subtract one so that labels are in [0, 1000), and cast to float32 for # Subtract one so that labels are in [0, 1000), and cast to float32 for
# Keras model. # Keras model.
label = tf.cast(tf.cast(tf.reshape(label, shape=[1]), dtype=tf.int32) - 1, label = tf.cast(tf.cast(tf.reshape(label, shape=[1]), dtype=tf.int32) - 1,
dtype=tf.float32) dtype=tf.float32)
"""
return image, label return image, label
def get_synth_input_fn(height, width, num_channels, num_classes,
dtype=tf.float32):
"""Returns an input function that returns a dataset with random data.
This input_fn returns a data set that iterates over a set of random data and
bypasses all preprocessing, e.g. jpeg decode and copy. The host to device
copy is still included. This used to find the upper throughput bound when
tunning the full input pipeline.
Args:
height: Integer height that will be used to create a fake image tensor.
width: Integer width that will be used to create a fake image tensor.
num_channels: Integer depth that will be used to create a fake image tensor.
num_classes: Number of classes that should be represented in the fake labels
tensor
dtype: Data type for features/images.
Returns:
An input_fn that can be used in place of a real one to return a dataset
that can be used for iteration.
"""
# pylint: disable=unused-argument
def input_fn(is_training, data_dir, batch_size, *args, **kwargs):
"""Returns dataset filled with random data."""
# Synthetic input should be within [0, 255].
inputs = tf.truncated_normal(
[batch_size] + [height, width, num_channels],
dtype=dtype,
mean=127,
stddev=60,
name='synthetic_inputs')
labels = tf.random_uniform(
[batch_size] + [1],
minval=0,
maxval=num_classes - 1,
dtype=tf.int32,
name='synthetic_labels')
data = tf.data.Dataset.from_tensors((inputs, labels)).repeat()
data = data.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
return data
return input_fn
def run_imagenet_with_keras(flags_obj): def run_imagenet_with_keras(flags_obj):
"""Run ResNet ImageNet training and eval loop using native Keras APIs. """Run ResNet ImageNet training and eval loop using native Keras APIs.
...@@ -116,51 +148,38 @@ def run_imagenet_with_keras(flags_obj): ...@@ -116,51 +148,38 @@ def run_imagenet_with_keras(flags_obj):
# pylint: disable=protected-access # pylint: disable=protected-access
if flags_obj.use_synthetic_data: if flags_obj.use_synthetic_data:
synth_input_fn = resnet_run_loop.get_synth_input_fn( input_fn = get_synth_input_fn(
imagenet_main.DEFAULT_IMAGE_SIZE, imagenet_main.DEFAULT_IMAGE_SIZE,
imagenet_main.NUM_CHANNELS, imagenet_main.NUM_CLASSES,
dtype=flags_core.get_tf_dtype(flags_obj))
train_input_dataset = synth_input_fn(
batch_size=per_device_batch_size,
height=imagenet_main.DEFAULT_IMAGE_SIZE,
width=imagenet_main.DEFAULT_IMAGE_SIZE,
num_channels=imagenet_main.NUM_CHANNELS,
num_classes=imagenet_main.NUM_CLASSES,
dtype=dtype)
eval_input_dataset = synth_input_fn(
batch_size=per_device_batch_size,
height=imagenet_main.DEFAULT_IMAGE_SIZE, height=imagenet_main.DEFAULT_IMAGE_SIZE,
width=imagenet_main.DEFAULT_IMAGE_SIZE, width=imagenet_main.DEFAULT_IMAGE_SIZE,
num_channels=imagenet_main.NUM_CHANNELS, num_channels=imagenet_main.NUM_CHANNELS,
num_classes=imagenet_main.NUM_CLASSES, num_classes=imagenet_main.NUM_CLASSES,
dtype=dtype) dtype=flags_core.get_tf_dtype(flags_obj))
# pylint: enable=protected-access
else: else:
train_input_dataset = imagenet_main.input_fn( input_fn = imagenet_main.input_fn
True,
flags_obj.data_dir, train_input_dataset = input_fn(
is_training=True,
data_dir=flags_obj.data_dir,
batch_size=per_device_batch_size, batch_size=per_device_batch_size,
num_epochs=flags_obj.train_epochs, num_epochs=flags_obj.train_epochs,
parse_record_fn=parse_record_keras) parse_record_fn=parse_record_keras)
eval_input_dataset = imagenet_main.input_fn( eval_input_dataset = input_fn(
False, is_training=False,
flags_obj.data_dir, data_dir=flags_obj.data_dir,
batch_size=per_device_batch_size, batch_size=per_device_batch_size,
num_epochs=flags_obj.train_epochs, num_epochs=flags_obj.train_epochs,
parse_record_fn=parse_record_keras) parse_record_fn=parse_record_keras)
optimizer = keras_common.get_optimizer() optimizer = keras_common.get_optimizer()
strategy = distribution_utils.get_distribution_strategy( strategy = distribution_utils.get_distribution_strategy(
flags_obj.num_gpus, flags_obj.use_one_device_strategy) flags_obj.num_gpus, flags_obj.use_one_device_strategy)
model = resnet50.ResNet50(num_classes=imagenet_main.NUM_CLASSES) model = resnet50.ResNet50(num_classes=imagenet_main.NUM_CLASSES)
model.compile(loss='categorical_crossentropy', model.compile(loss='sparse_categorical_crossentropy',
optimizer=optimizer, optimizer=optimizer,
metrics=['categorical_accuracy'], metrics=['sparse_categorical_accuracy'],
distribute=strategy) distribute=strategy)
time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks( time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment