Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
9b98e3db
Commit
9b98e3db
authored
Nov 21, 2019
by
Allen Wang
Committed by
saberkun
Nov 21, 2019
Browse files
Internal change
PiperOrigin-RevId: 281793430
parent
3635527d
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
34 additions
and
12 deletions
+34
-12
official/vision/image_classification/common.py
official/vision/image_classification/common.py
+34
-12
No files found.
official/vision/image_classification/common.py
View file @
9b98e3db
...
@@ -356,6 +356,35 @@ def define_keras_flags(dynamic_loss_scale=True):
...
@@ -356,6 +356,35 @@ def define_keras_flags(dynamic_loss_scale=True):
'steps per epoch.'
)
'steps per epoch.'
)
def
get_synth_data
(
height
,
width
,
num_channels
,
num_classes
,
dtype
):
"""Creates a set of synthetic random data.
Args:
height: Integer height that will be used to create a fake image tensor.
width: Integer width that will be used to create a fake image tensor.
num_channels: Integer depth that will be used to create a fake image tensor.
num_classes: Number of classes that should be represented in the fake labels
tensor
dtype: Data type for features/images.
Returns:
A tuple of tensors representing the inputs and labels.
"""
# Synthetic input should be within [0, 255].
inputs
=
tf
.
random
.
truncated_normal
([
height
,
width
,
num_channels
],
dtype
=
dtype
,
mean
=
127
,
stddev
=
60
,
name
=
'synthetic_inputs'
)
labels
=
tf
.
random
.
uniform
([
1
],
minval
=
0
,
maxval
=
num_classes
-
1
,
dtype
=
tf
.
int32
,
name
=
'synthetic_labels'
)
return
inputs
,
labels
def
get_synth_input_fn
(
height
,
width
,
num_channels
,
num_classes
,
def
get_synth_input_fn
(
height
,
width
,
num_channels
,
num_classes
,
dtype
=
tf
.
float32
,
drop_remainder
=
True
):
dtype
=
tf
.
float32
,
drop_remainder
=
True
):
"""Returns an input function that returns a dataset with random data.
"""Returns an input function that returns a dataset with random data.
...
@@ -382,20 +411,13 @@ def get_synth_input_fn(height, width, num_channels, num_classes,
...
@@ -382,20 +411,13 @@ def get_synth_input_fn(height, width, num_channels, num_classes,
# pylint: disable=unused-argument
# pylint: disable=unused-argument
def
input_fn
(
is_training
,
data_dir
,
batch_size
,
*
args
,
**
kwargs
):
def
input_fn
(
is_training
,
data_dir
,
batch_size
,
*
args
,
**
kwargs
):
"""Returns dataset filled with random data."""
"""Returns dataset filled with random data."""
# Synthetic input should be within [0, 255].
inputs
,
labels
=
get_synth_data
(
height
=
height
,
inputs
=
tf
.
random
.
truncated_normal
([
height
,
width
,
num_channels
],
width
=
width
,
dtype
=
dtype
,
num_channels
=
num_channels
,
mean
=
127
,
num_classes
=
num_classes
,
stddev
=
60
,
dtype
=
dtype
)
name
=
'synthetic_inputs'
)
labels
=
tf
.
random
.
uniform
([
1
],
minval
=
0
,
maxval
=
num_classes
-
1
,
dtype
=
tf
.
int32
,
name
=
'synthetic_labels'
)
# Cast to float32 for Keras model.
# Cast to float32 for Keras model.
labels
=
tf
.
cast
(
labels
,
dtype
=
tf
.
float32
)
labels
=
tf
.
cast
(
labels
,
dtype
=
tf
.
float32
)
data
=
tf
.
data
.
Dataset
.
from_tensors
((
inputs
,
labels
)).
repeat
()
data
=
tf
.
data
.
Dataset
.
from_tensors
((
inputs
,
labels
)).
repeat
()
# `drop_remainder` will make dataset produce outputs with known shapes.
# `drop_remainder` will make dataset produce outputs with known shapes.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment