Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
c9972ad6
"docs/vscode:/vscode.git/clone" did not exist on "4fd6e7103006e01f7a4f5d723b13ea0e789ff3ce"
Commit
c9972ad6
authored
Sep 02, 2018
by
Toby Boyd
Browse files
Improve synthic data performance
parent
23b5b422
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
42 additions
and
23 deletions
+42
-23
official/resnet/cifar10_main.py
official/resnet/cifar10_main.py
+5
-4
official/resnet/cifar10_test.py
official/resnet/cifar10_test.py
+2
-2
official/resnet/imagenet_main.py
official/resnet/imagenet_main.py
+6
-4
official/resnet/imagenet_test.py
official/resnet/imagenet_test.py
+2
-2
official/resnet/resnet_run_loop.py
official/resnet/resnet_run_loop.py
+27
-11
No files found.
official/resnet/cifar10_main.py
View file @
c9972ad6
...
@@ -135,9 +135,9 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None):
...
@@ -135,9 +135,9 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None):
)
)
def
get_synth_input_fn
():
def
get_synth_input_fn
(
dtype
):
return
resnet_run_loop
.
get_synth_input_fn
(
return
resnet_run_loop
.
get_synth_input_fn
(
_HEIGHT
,
_WIDTH
,
_NUM_CHANNELS
,
_NUM_CLASSES
)
_HEIGHT
,
_WIDTH
,
_NUM_CHANNELS
,
_NUM_CLASSES
,
dtype
=
dtype
)
###############################################################################
###############################################################################
...
@@ -243,8 +243,9 @@ def run_cifar(flags_obj):
...
@@ -243,8 +243,9 @@ def run_cifar(flags_obj):
Args:
Args:
flags_obj: An object containing parsed flag values.
flags_obj: An object containing parsed flag values.
"""
"""
input_function
=
(
flags_obj
.
use_synthetic_data
and
get_synth_input_fn
()
input_function
=
(
flags_obj
.
use_synthetic_data
and
or
input_fn
)
get_synth_input_fn
(
flags_core
.
get_tf_dtype
(
flags_obj
))
or
input_fn
)
resnet_run_loop
.
resnet_main
(
resnet_run_loop
.
resnet_main
(
flags_obj
,
cifar10_model_fn
,
input_function
,
DATASET_NAME
,
flags_obj
,
cifar10_model_fn
,
input_function
,
DATASET_NAME
,
shape
=
[
_HEIGHT
,
_WIDTH
,
_NUM_CHANNELS
])
shape
=
[
_HEIGHT
,
_WIDTH
,
_NUM_CHANNELS
])
...
...
official/resnet/cifar10_test.py
View file @
c9972ad6
...
@@ -77,9 +77,9 @@ class BaseTest(tf.test.TestCase):
...
@@ -77,9 +77,9 @@ class BaseTest(tf.test.TestCase):
self
.
assertAllClose
(
pixel
,
np
.
array
([
-
1.225
,
0.
,
1.225
]),
rtol
=
1e-3
)
self
.
assertAllClose
(
pixel
,
np
.
array
([
-
1.225
,
0.
,
1.225
]),
rtol
=
1e-3
)
def
cifar10_model_fn_helper
(
self
,
mode
,
resnet_version
,
dtype
):
def
cifar10_model_fn_helper
(
self
,
mode
,
resnet_version
,
dtype
):
input_fn
=
cifar10_main
.
get_synth_input_fn
()
input_fn
=
cifar10_main
.
get_synth_input_fn
(
dtype
)
dataset
=
input_fn
(
True
,
''
,
_BATCH_SIZE
)
dataset
=
input_fn
(
True
,
''
,
_BATCH_SIZE
)
iterator
=
dataset
.
make_
one_shot
_iterator
()
iterator
=
dataset
.
make_
initializable
_iterator
()
features
,
labels
=
iterator
.
get_next
()
features
,
labels
=
iterator
.
get_next
()
spec
=
cifar10_main
.
cifar10_model_fn
(
spec
=
cifar10_main
.
cifar10_model_fn
(
features
,
labels
,
mode
,
{
features
,
labels
,
mode
,
{
...
...
official/resnet/imagenet_main.py
View file @
c9972ad6
...
@@ -196,9 +196,10 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None):
...
@@ -196,9 +196,10 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None):
)
)
def
get_synth_input_fn
():
def
get_synth_input_fn
(
dtype
):
return
resnet_run_loop
.
get_synth_input_fn
(
return
resnet_run_loop
.
get_synth_input_fn
(
_DEFAULT_IMAGE_SIZE
,
_DEFAULT_IMAGE_SIZE
,
_NUM_CHANNELS
,
_NUM_CLASSES
)
_DEFAULT_IMAGE_SIZE
,
_DEFAULT_IMAGE_SIZE
,
_NUM_CHANNELS
,
_NUM_CLASSES
,
dtype
=
dtype
)
###############################################################################
###############################################################################
...
@@ -331,8 +332,9 @@ def run_imagenet(flags_obj):
...
@@ -331,8 +332,9 @@ def run_imagenet(flags_obj):
Args:
Args:
flags_obj: An object containing parsed flag values.
flags_obj: An object containing parsed flag values.
"""
"""
input_function
=
(
flags_obj
.
use_synthetic_data
and
get_synth_input_fn
()
input_function
=
(
flags_obj
.
use_synthetic_data
and
or
input_fn
)
get_synth_input_fn
(
flags_core
.
get_tf_dtype
(
flags_obj
))
or
input_fn
)
resnet_run_loop
.
resnet_main
(
resnet_run_loop
.
resnet_main
(
flags_obj
,
imagenet_model_fn
,
input_function
,
DATASET_NAME
,
flags_obj
,
imagenet_model_fn
,
input_function
,
DATASET_NAME
,
...
...
official/resnet/imagenet_test.py
View file @
c9972ad6
...
@@ -191,9 +191,9 @@ class BaseTest(tf.test.TestCase):
...
@@ -191,9 +191,9 @@ class BaseTest(tf.test.TestCase):
"""Tests that the EstimatorSpec is given the appropriate arguments."""
"""Tests that the EstimatorSpec is given the appropriate arguments."""
tf
.
train
.
create_global_step
()
tf
.
train
.
create_global_step
()
input_fn
=
imagenet_main
.
get_synth_input_fn
()
input_fn
=
imagenet_main
.
get_synth_input_fn
(
dtype
)
dataset
=
input_fn
(
True
,
''
,
_BATCH_SIZE
)
dataset
=
input_fn
(
True
,
''
,
_BATCH_SIZE
)
iterator
=
dataset
.
make_
one_shot
_iterator
()
iterator
=
dataset
.
make_
initializable
_iterator
()
features
,
labels
=
iterator
.
get_next
()
features
,
labels
=
iterator
.
get_next
()
spec
=
imagenet_main
.
imagenet_model_fn
(
spec
=
imagenet_main
.
imagenet_model_fn
(
features
,
labels
,
mode
,
{
features
,
labels
,
mode
,
{
...
...
official/resnet/resnet_run_loop.py
View file @
c9972ad6
...
@@ -108,11 +108,12 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
...
@@ -108,11 +108,12 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
return
dataset
return
dataset
def
get_synth_input_fn
(
height
,
width
,
num_channels
,
num_classes
):
def
get_synth_input_fn
(
height
,
width
,
num_channels
,
num_classes
,
"""Returns an input function that returns a dataset with zeroes.
dtype
=
tf
.
float32
):
"""Returns an input function that returns a dataset with random data.
This i
s useful in debugging input pipeline performance, as it removes all
This i
nput_fn removed all aspects of the input pipeline other than the
elements of file reading and image preprocessing
.
host to device copy. This is useful in debugging input pipeline performance
.
Args:
Args:
height: Integer height that will be used to create a fake image tensor.
height: Integer height that will be used to create a fake image tensor.
...
@@ -120,17 +121,32 @@ def get_synth_input_fn(height, width, num_channels, num_classes):
...
@@ -120,17 +121,32 @@ def get_synth_input_fn(height, width, num_channels, num_classes):
num_channels: Integer depth that will be used to create a fake image tensor.
num_channels: Integer depth that will be used to create a fake image tensor.
num_classes: Number of classes that should be represented in the fake labels
num_classes: Number of classes that should be represented in the fake labels
tensor
tensor
dtype: Data type for features/images.
Returns:
Returns:
An input_fn that can be used in place of a real one to return a dataset
An input_fn that can be used in place of a real one to return a dataset
that can be used for iteration.
that can be used for iteration.
"""
"""
def
input_fn
(
is_training
,
data_dir
,
batch_size
,
*
args
,
**
kwargs
):
# pylint: disable=unused-argument
# pylint: disable=unused-argument
return
model_helpers
.
generate_synthetic_data
(
def
input_fn
(
is_training
,
data_dir
,
batch_size
,
*
args
,
**
kwargs
):
input_shape
=
tf
.
TensorShape
([
batch_size
,
height
,
width
,
num_channels
]),
"""Returns dataset filled with random data."""
input_dtype
=
tf
.
float32
,
# Synthetic input should be within [0, 255].
label_shape
=
tf
.
TensorShape
([
batch_size
]),
inputs
=
tf
.
truncated_normal
(
label_dtype
=
tf
.
int32
)
[
batch_size
]
+
[
height
,
width
,
num_channels
],
dtype
=
dtype
,
mean
=
127
,
stddev
=
60
,
name
=
'synthetic_inputs'
)
labels
=
tf
.
random_uniform
(
[
batch_size
],
minval
=
0
,
maxval
=
num_classes
-
1
,
dtype
=
tf
.
int32
,
name
=
'synthetic_labels'
)
data
=
tf
.
data
.
Dataset
.
from_tensors
((
inputs
,
labels
)).
repeat
()
data
=
data
.
prefetch
(
buffer_size
=
tf
.
contrib
.
data
.
AUTOTUNE
)
return
data
return
input_fn
return
input_fn
...
@@ -230,7 +246,7 @@ def resnet_model_fn(features, labels, mode, model_class,
...
@@ -230,7 +246,7 @@ def resnet_model_fn(features, labels, mode, model_class,
# Generate a summary node for the images
# Generate a summary node for the images
tf
.
summary
.
image
(
'images'
,
features
,
max_outputs
=
6
)
tf
.
summary
.
image
(
'images'
,
features
,
max_outputs
=
6
)
# TODO(tobyboyd): Add cast as part of input pipeline on cpu and remove.
features
=
tf
.
cast
(
features
,
dtype
)
features
=
tf
.
cast
(
features
,
dtype
)
model
=
model_class
(
resnet_size
,
data_format
,
resnet_version
=
resnet_version
,
model
=
model_class
(
resnet_size
,
data_format
,
resnet_version
=
resnet_version
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment