Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
481728db
Unverified
Commit
481728db
authored
Sep 04, 2018
by
Toby Boyd
Committed by
GitHub
Sep 04, 2018
Browse files
Merge pull request #5225 from tfboyd/resnet_synthetic_fix
ResNet synthetic data performance enhancement.
parents
e0f6a392
967133c1
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
44 additions
and
23 deletions
+44
-23
official/resnet/cifar10_main.py
official/resnet/cifar10_main.py
+5
-4
official/resnet/cifar10_test.py
official/resnet/cifar10_test.py
+2
-2
official/resnet/imagenet_main.py
official/resnet/imagenet_main.py
+6
-4
official/resnet/imagenet_test.py
official/resnet/imagenet_test.py
+2
-2
official/resnet/resnet_run_loop.py
official/resnet/resnet_run_loop.py
+29
-11
No files found.
official/resnet/cifar10_main.py
View file @
481728db
...
...
@@ -135,9 +135,9 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None):
)
def
get_synth_input_fn
():
def
get_synth_input_fn
(
dtype
):
return
resnet_run_loop
.
get_synth_input_fn
(
_HEIGHT
,
_WIDTH
,
_NUM_CHANNELS
,
_NUM_CLASSES
)
_HEIGHT
,
_WIDTH
,
_NUM_CHANNELS
,
_NUM_CLASSES
,
dtype
=
dtype
)
###############################################################################
...
...
@@ -243,8 +243,9 @@ def run_cifar(flags_obj):
Args:
flags_obj: An object containing parsed flag values.
"""
input_function
=
(
flags_obj
.
use_synthetic_data
and
get_synth_input_fn
()
or
input_fn
)
input_function
=
(
flags_obj
.
use_synthetic_data
and
get_synth_input_fn
(
flags_core
.
get_tf_dtype
(
flags_obj
))
or
input_fn
)
resnet_run_loop
.
resnet_main
(
flags_obj
,
cifar10_model_fn
,
input_function
,
DATASET_NAME
,
shape
=
[
_HEIGHT
,
_WIDTH
,
_NUM_CHANNELS
])
...
...
official/resnet/cifar10_test.py
View file @
481728db
...
...
@@ -77,9 +77,9 @@ class BaseTest(tf.test.TestCase):
self
.
assertAllClose
(
pixel
,
np
.
array
([
-
1.225
,
0.
,
1.225
]),
rtol
=
1e-3
)
def
cifar10_model_fn_helper
(
self
,
mode
,
resnet_version
,
dtype
):
input_fn
=
cifar10_main
.
get_synth_input_fn
()
input_fn
=
cifar10_main
.
get_synth_input_fn
(
dtype
)
dataset
=
input_fn
(
True
,
''
,
_BATCH_SIZE
)
iterator
=
dataset
.
make_
one_shot
_iterator
()
iterator
=
dataset
.
make_
initializable
_iterator
()
features
,
labels
=
iterator
.
get_next
()
spec
=
cifar10_main
.
cifar10_model_fn
(
features
,
labels
,
mode
,
{
...
...
official/resnet/imagenet_main.py
View file @
481728db
...
...
@@ -196,9 +196,10 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None):
)
def
get_synth_input_fn
():
def
get_synth_input_fn
(
dtype
):
return
resnet_run_loop
.
get_synth_input_fn
(
_DEFAULT_IMAGE_SIZE
,
_DEFAULT_IMAGE_SIZE
,
_NUM_CHANNELS
,
_NUM_CLASSES
)
_DEFAULT_IMAGE_SIZE
,
_DEFAULT_IMAGE_SIZE
,
_NUM_CHANNELS
,
_NUM_CLASSES
,
dtype
=
dtype
)
###############################################################################
...
...
@@ -331,8 +332,9 @@ def run_imagenet(flags_obj):
Args:
flags_obj: An object containing parsed flag values.
"""
input_function
=
(
flags_obj
.
use_synthetic_data
and
get_synth_input_fn
()
or
input_fn
)
input_function
=
(
flags_obj
.
use_synthetic_data
and
get_synth_input_fn
(
flags_core
.
get_tf_dtype
(
flags_obj
))
or
input_fn
)
resnet_run_loop
.
resnet_main
(
flags_obj
,
imagenet_model_fn
,
input_function
,
DATASET_NAME
,
...
...
official/resnet/imagenet_test.py
View file @
481728db
...
...
@@ -191,9 +191,9 @@ class BaseTest(tf.test.TestCase):
"""Tests that the EstimatorSpec is given the appropriate arguments."""
tf
.
train
.
create_global_step
()
input_fn
=
imagenet_main
.
get_synth_input_fn
()
input_fn
=
imagenet_main
.
get_synth_input_fn
(
dtype
)
dataset
=
input_fn
(
True
,
''
,
_BATCH_SIZE
)
iterator
=
dataset
.
make_
one_shot
_iterator
()
iterator
=
dataset
.
make_
initializable
_iterator
()
features
,
labels
=
iterator
.
get_next
()
spec
=
imagenet_main
.
imagenet_model_fn
(
features
,
labels
,
mode
,
{
...
...
official/resnet/resnet_run_loop.py
View file @
481728db
...
...
@@ -108,11 +108,14 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
return
dataset
def
get_synth_input_fn
(
height
,
width
,
num_channels
,
num_classes
):
"""Returns an input function that returns a dataset with zeroes.
def
get_synth_input_fn
(
height
,
width
,
num_channels
,
num_classes
,
dtype
=
tf
.
float32
):
"""Returns an input function that returns a dataset with random data.
This is useful in debugging input pipeline performance, as it removes all
elements of file reading and image preprocessing.
This input_fn returns a data set that iterates over a set of random data and
bypasses all preprocessing, e.g. jpeg decode and copy. The host to device
copy is still included. This used to find the upper throughput bound when
tunning the full input pipeline.
Args:
height: Integer height that will be used to create a fake image tensor.
...
...
@@ -120,17 +123,32 @@ def get_synth_input_fn(height, width, num_channels, num_classes):
num_channels: Integer depth that will be used to create a fake image tensor.
num_classes: Number of classes that should be represented in the fake labels
tensor
dtype: Data type for features/images.
Returns:
An input_fn that can be used in place of a real one to return a dataset
that can be used for iteration.
"""
def
input_fn
(
is_training
,
data_dir
,
batch_size
,
*
args
,
**
kwargs
):
# pylint: disable=unused-argument
return
model_helpers
.
generate_synthetic_data
(
input_shape
=
tf
.
TensorShape
([
batch_size
,
height
,
width
,
num_channels
]),
input_dtype
=
tf
.
float32
,
label_shape
=
tf
.
TensorShape
([
batch_size
]),
label_dtype
=
tf
.
int32
)
# pylint: disable=unused-argument
def
input_fn
(
is_training
,
data_dir
,
batch_size
,
*
args
,
**
kwargs
):
"""Returns dataset filled with random data."""
# Synthetic input should be within [0, 255].
inputs
=
tf
.
truncated_normal
(
[
batch_size
]
+
[
height
,
width
,
num_channels
],
dtype
=
dtype
,
mean
=
127
,
stddev
=
60
,
name
=
'synthetic_inputs'
)
labels
=
tf
.
random_uniform
(
[
batch_size
],
minval
=
0
,
maxval
=
num_classes
-
1
,
dtype
=
tf
.
int32
,
name
=
'synthetic_labels'
)
data
=
tf
.
data
.
Dataset
.
from_tensors
((
inputs
,
labels
)).
repeat
()
data
=
data
.
prefetch
(
buffer_size
=
tf
.
contrib
.
data
.
AUTOTUNE
)
return
data
return
input_fn
...
...
@@ -230,7 +248,7 @@ def resnet_model_fn(features, labels, mode, model_class,
# Generate a summary node for the images
tf
.
summary
.
image
(
'images'
,
features
,
max_outputs
=
6
)
# TODO(tobyboyd): Add cast as part of input pipeline on cpu and remove.
features
=
tf
.
cast
(
features
,
dtype
)
model
=
model_class
(
resnet_size
,
data_format
,
resnet_version
=
resnet_version
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment