"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "df9d0b474ed8cb4f4dff1aa7b196700c6e81c104"
Commit 9b98e3db authored by Allen Wang's avatar Allen Wang Committed by saberkun
Browse files

Internal change

PiperOrigin-RevId: 281793430
parent 3635527d
...@@ -356,6 +356,35 @@ def define_keras_flags(dynamic_loss_scale=True): ...@@ -356,6 +356,35 @@ def define_keras_flags(dynamic_loss_scale=True):
'steps per epoch.') 'steps per epoch.')
def get_synth_data(height, width, num_channels, num_classes, dtype):
"""Creates a set of synthetic random data.
Args:
height: Integer height that will be used to create a fake image tensor.
width: Integer width that will be used to create a fake image tensor.
num_channels: Integer depth that will be used to create a fake image tensor.
num_classes: Number of classes that should be represented in the fake labels
tensor
dtype: Data type for features/images.
Returns:
A tuple of tensors representing the inputs and labels.
"""
# Synthetic input should be within [0, 255].
inputs = tf.random.truncated_normal([height, width, num_channels],
dtype=dtype,
mean=127,
stddev=60,
name='synthetic_inputs')
labels = tf.random.uniform([1],
minval=0,
maxval=num_classes - 1,
dtype=tf.int32,
name='synthetic_labels')
return inputs, labels
def get_synth_input_fn(height, width, num_channels, num_classes, def get_synth_input_fn(height, width, num_channels, num_classes,
dtype=tf.float32, drop_remainder=True): dtype=tf.float32, drop_remainder=True):
"""Returns an input function that returns a dataset with random data. """Returns an input function that returns a dataset with random data.
...@@ -382,20 +411,13 @@ def get_synth_input_fn(height, width, num_channels, num_classes, ...@@ -382,20 +411,13 @@ def get_synth_input_fn(height, width, num_channels, num_classes,
# pylint: disable=unused-argument # pylint: disable=unused-argument
def input_fn(is_training, data_dir, batch_size, *args, **kwargs): def input_fn(is_training, data_dir, batch_size, *args, **kwargs):
"""Returns dataset filled with random data.""" """Returns dataset filled with random data."""
# Synthetic input should be within [0, 255]. inputs, labels = get_synth_data(height=height,
inputs = tf.random.truncated_normal([height, width, num_channels], width=width,
dtype=dtype, num_channels=num_channels,
mean=127, num_classes=num_classes,
stddev=60, dtype=dtype)
name='synthetic_inputs')
labels = tf.random.uniform([1],
minval=0,
maxval=num_classes - 1,
dtype=tf.int32,
name='synthetic_labels')
# Cast to float32 for Keras model. # Cast to float32 for Keras model.
labels = tf.cast(labels, dtype=tf.float32) labels = tf.cast(labels, dtype=tf.float32)
data = tf.data.Dataset.from_tensors((inputs, labels)).repeat() data = tf.data.Dataset.from_tensors((inputs, labels)).repeat()
# `drop_remainder` will make dataset produce outputs with known shapes. # `drop_remainder` will make dataset produce outputs with known shapes.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment