"git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "41038d54947eabd716e75c5f3f98c829a4a4cd37"
Commit 80dcd27c authored by Shining Sun's avatar Shining Sun
Browse files

Synth data for cifar

parent 2f8481da
...@@ -101,46 +101,29 @@ def run(flags_obj): ...@@ -101,46 +101,29 @@ def run(flags_obj):
per_device_batch_size = distribution_utils.per_device_batch_size( per_device_batch_size = distribution_utils.per_device_batch_size(
flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)) flags_obj.batch_size, flags_core.get_num_gpus(flags_obj))
# pylint: disable=protected-access
if flags_obj.use_synthetic_data: if flags_obj.use_synthetic_data:
synth_input_fn = resnet_run_loop.get_synth_input_fn( input_fn = keras_common.get_synth_input_fn(
cifar_main.HEIGHT, cifar_main.WIDTH,
cifar_main.NUM_CHANNELS, cifar_main.NUM_CLASSES,
dtype=flags_core.get_tf_dtype(flags_obj))
train_input_dataset = synth_input_fn(
True,
flags_obj.data_dir,
batch_size=per_device_batch_size,
height=cifar_main.HEIGHT, height=cifar_main.HEIGHT,
width=cifar_main.WIDTH, width=cifar_main.WIDTH,
num_channels=cifar_main.NUM_CHANNELS, num_channels=cifar_main.NUM_CHANNELS,
num_classes=cifar_main.NUM_CLASSES, num_classes=cifar_main.NUM_CLASSES,
dtype=dtype) dtype=flags_core.get_tf_dtype(flags_obj))
eval_input_dataset = synth_input_fn(
False,
flags_obj.data_dir,
batch_size=per_device_batch_size,
height=cifar_main.HEIGHT,
width=cifar_main.WIDTH,
num_channels=cifar_main.NUM_CHANNELS,
num_classes=cifar_main.NUM_CLASSES,
dtype=dtype)
# pylint: enable=protected-access
else: else:
train_input_dataset = cifar_main.input_fn( input_fn = cifar_main.input_fn
True,
flags_obj.data_dir, train_input_dataset = input_fn(
batch_size=per_device_batch_size, is_training=True,
num_epochs=flags_obj.train_epochs, data_dir=flags_obj.data_dir,
parse_record_fn=parse_record_keras) batch_size=per_device_batch_size,
num_epochs=flags_obj.train_epochs,
eval_input_dataset = cifar_main.input_fn( parse_record_fn=parse_record_keras)
False,
flags_obj.data_dir, eval_input_dataset = input_fn(
batch_size=per_device_batch_size, is_training=False,
num_epochs=flags_obj.train_epochs, data_dir=flags_obj.data_dir,
parse_record_fn=parse_record_keras) batch_size=per_device_batch_size,
num_epochs=flags_obj.train_epochs,
parse_record_fn=parse_record_keras)
optimizer = keras_common.get_optimizer() optimizer = keras_common.get_optimizer()
strategy = distribution_utils.get_distribution_strategy( strategy = distribution_utils.get_distribution_strategy(
......
...@@ -144,3 +144,48 @@ def define_keras_flags(): ...@@ -144,3 +144,48 @@ def define_keras_flags():
name="train_steps", default=None, name="train_steps", default=None,
help="The number of steps to run for training") help="The number of steps to run for training")
def get_synth_input_fn(height, width, num_channels, num_classes,
                       dtype=tf.float32):
  """Returns an input function that returns a dataset with random data.

  This input_fn returns a data set that iterates over a set of random data and
  bypasses all preprocessing, e.g. jpeg decode and copy. The host to device
  copy is still included. This is used to find the upper throughput bound when
  tuning the full input pipeline.

  Args:
    height: Integer height that will be used to create a fake image tensor.
    width: Integer width that will be used to create a fake image tensor.
    num_channels: Integer depth that will be used to create a fake image tensor.
    num_classes: Number of classes that should be represented in the fake labels
      tensor
    dtype: Data type for features/images.

  Returns:
    An input_fn that can be used in place of a real one to return a dataset
    that can be used for iteration.
  """
  # input_fn deliberately ignores data_dir and is_training (and any extra
  # args a caller passes through) — synthetic data needs neither.
  # pylint: disable=unused-argument
  def input_fn(is_training, data_dir, batch_size, *args, **kwargs):
    """Returns dataset filled with random data."""
    # Synthetic input should be within [0, 255]; mean/stddev approximate the
    # spread of real uint8 image data.
    inputs = tf.truncated_normal(
        [batch_size] + [height, width, num_channels],
        dtype=dtype,
        mean=127,
        stddev=60,
        name='synthetic_inputs')
    # tf.random_uniform's maxval is exclusive, so use num_classes (not
    # num_classes - 1) so that labels cover the full range
    # [0, num_classes) and the last class can actually be generated.
    labels = tf.random_uniform(
        [batch_size] + [1],
        minval=0,
        maxval=num_classes,
        dtype=tf.int32,
        name='synthetic_labels')
    # A single cached tensor repeated forever: no per-step input work beyond
    # the host-to-device copy.
    data = tf.data.Dataset.from_tensors((inputs, labels)).repeat()
    data = data.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
    return data
  return input_fn
...@@ -81,51 +81,6 @@ def parse_record_keras(raw_record, is_training, dtype): ...@@ -81,51 +81,6 @@ def parse_record_keras(raw_record, is_training, dtype):
return image, label return image, label
def get_synth_input_fn(height, width, num_channels, num_classes,
                       dtype=tf.float32):
  """Returns an input function that returns a dataset with random data.

  This input_fn returns a data set that iterates over a set of random data and
  bypasses all preprocessing, e.g. jpeg decode and copy. The host to device
  copy is still included. This is used to find the upper throughput bound when
  tuning the full input pipeline.

  Args:
    height: Integer height that will be used to create a fake image tensor.
    width: Integer width that will be used to create a fake image tensor.
    num_channels: Integer depth that will be used to create a fake image tensor.
    num_classes: Number of classes that should be represented in the fake labels
      tensor
    dtype: Data type for features/images.

  Returns:
    An input_fn that can be used in place of a real one to return a dataset
    that can be used for iteration.
  """
  # input_fn deliberately ignores data_dir and is_training (and any extra
  # args a caller passes through) — synthetic data needs neither.
  # pylint: disable=unused-argument
  def input_fn(is_training, data_dir, batch_size, *args, **kwargs):
    """Returns dataset filled with random data."""
    # Synthetic input should be within [0, 255]; mean/stddev approximate the
    # spread of real uint8 image data.
    inputs = tf.truncated_normal(
        [batch_size] + [height, width, num_channels],
        dtype=dtype,
        mean=127,
        stddev=60,
        name='synthetic_inputs')
    # tf.random_uniform's maxval is exclusive, so use num_classes (not
    # num_classes - 1) so that labels cover the full range
    # [0, num_classes) and the last class can actually be generated.
    labels = tf.random_uniform(
        [batch_size] + [1],
        minval=0,
        maxval=num_classes,
        dtype=tf.int32,
        name='synthetic_labels')
    # A single cached tensor repeated forever: no per-step input work beyond
    # the host-to-device copy.
    data = tf.data.Dataset.from_tensors((inputs, labels)).repeat()
    data = data.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
    return data
  return input_fn
def run_imagenet_with_keras(flags_obj): def run_imagenet_with_keras(flags_obj):
"""Run ResNet ImageNet training and eval loop using native Keras APIs. """Run ResNet ImageNet training and eval loop using native Keras APIs.
...@@ -148,7 +103,7 @@ def run_imagenet_with_keras(flags_obj): ...@@ -148,7 +103,7 @@ def run_imagenet_with_keras(flags_obj):
# pylint: disable=protected-access # pylint: disable=protected-access
if flags_obj.use_synthetic_data: if flags_obj.use_synthetic_data:
input_fn = get_synth_input_fn( input_fn = keras_common.get_synth_input_fn(
height=imagenet_main.DEFAULT_IMAGE_SIZE, height=imagenet_main.DEFAULT_IMAGE_SIZE,
width=imagenet_main.DEFAULT_IMAGE_SIZE, width=imagenet_main.DEFAULT_IMAGE_SIZE,
num_channels=imagenet_main.NUM_CHANNELS, num_channels=imagenet_main.NUM_CHANNELS,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment