Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
7b304676
Unverified
Commit
7b304676
authored
Sep 05, 2018
by
Toby Boyd
Committed by
GitHub
Sep 05, 2018
Browse files
Merge pull request #5253 from tfboyd/resnet_input_pipeline_fp16
Move tf.cast to tf.float16 in input pipeline
parents
7babedc5
5c0c749b
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
27 additions
and
15 deletions
+27
-15
official/resnet/cifar10_main.py
official/resnet/cifar10_main.py
+8
-4
official/resnet/cifar10_test.py
official/resnet/cifar10_test.py
+1
-1
official/resnet/imagenet_main.py
official/resnet/imagenet_main.py
+8
-3
official/resnet/resnet_run_loop.py
official/resnet/resnet_run_loop.py
+10
-7
No files found.
official/resnet/cifar10_main.py
View file @
7b304676
...
@@ -66,7 +66,7 @@ def get_filenames(is_training, data_dir):
...
@@ -66,7 +66,7 @@ def get_filenames(is_training, data_dir):
return
[
os
.
path
.
join
(
data_dir
,
'test_batch.bin'
)]
return
[
os
.
path
.
join
(
data_dir
,
'test_batch.bin'
)]
def
parse_record
(
raw_record
,
is_training
):
def
parse_record
(
raw_record
,
is_training
,
dtype
):
"""Parse CIFAR-10 image and label from a raw record."""
"""Parse CIFAR-10 image and label from a raw record."""
# Convert bytes to a vector of uint8 that is record_bytes long.
# Convert bytes to a vector of uint8 that is record_bytes long.
record_vector
=
tf
.
decode_raw
(
raw_record
,
tf
.
uint8
)
record_vector
=
tf
.
decode_raw
(
raw_record
,
tf
.
uint8
)
...
@@ -85,6 +85,7 @@ def parse_record(raw_record, is_training):
...
@@ -85,6 +85,7 @@ def parse_record(raw_record, is_training):
image
=
tf
.
cast
(
tf
.
transpose
(
depth_major
,
[
1
,
2
,
0
]),
tf
.
float32
)
image
=
tf
.
cast
(
tf
.
transpose
(
depth_major
,
[
1
,
2
,
0
]),
tf
.
float32
)
image
=
preprocess_image
(
image
,
is_training
)
image
=
preprocess_image
(
image
,
is_training
)
image
=
tf
.
cast
(
image
,
dtype
)
return
image
,
label
return
image
,
label
...
@@ -107,8 +108,9 @@ def preprocess_image(image, is_training):
...
@@ -107,8 +108,9 @@ def preprocess_image(image, is_training):
return
image
return
image
def
input_fn
(
is_training
,
data_dir
,
batch_size
,
num_epochs
=
1
,
num_gpus
=
None
):
def
input_fn
(
is_training
,
data_dir
,
batch_size
,
num_epochs
=
1
,
num_gpus
=
None
,
"""Input_fn using the tf.data input pipeline for CIFAR-10 dataset.
dtype
=
tf
.
float32
):
"""Input function which provides batches for train or eval.
Args:
Args:
is_training: A boolean denoting whether the input is for training.
is_training: A boolean denoting whether the input is for training.
...
@@ -116,6 +118,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None):
...
@@ -116,6 +118,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None):
batch_size: The number of samples per batch.
batch_size: The number of samples per batch.
num_epochs: The number of epochs to repeat the dataset.
num_epochs: The number of epochs to repeat the dataset.
num_gpus: The number of gpus used for training.
num_gpus: The number of gpus used for training.
dtype: Data type to use for images/features.
Returns:
Returns:
A dataset that can be used for iteration.
A dataset that can be used for iteration.
...
@@ -131,7 +134,8 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None):
...
@@ -131,7 +134,8 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None):
parse_record_fn
=
parse_record
,
parse_record_fn
=
parse_record
,
num_epochs
=
num_epochs
,
num_epochs
=
num_epochs
,
num_gpus
=
num_gpus
,
num_gpus
=
num_gpus
,
examples_per_epoch
=
_NUM_IMAGES
[
'train'
]
if
is_training
else
None
examples_per_epoch
=
_NUM_IMAGES
[
'train'
]
if
is_training
else
None
,
dtype
=
dtype
)
)
...
...
official/resnet/cifar10_test.py
View file @
7b304676
...
@@ -61,7 +61,7 @@ class BaseTest(tf.test.TestCase):
...
@@ -61,7 +61,7 @@ class BaseTest(tf.test.TestCase):
fake_dataset
=
tf
.
data
.
FixedLengthRecordDataset
(
fake_dataset
=
tf
.
data
.
FixedLengthRecordDataset
(
filename
,
cifar10_main
.
_RECORD_BYTES
)
# pylint: disable=protected-access
filename
,
cifar10_main
.
_RECORD_BYTES
)
# pylint: disable=protected-access
fake_dataset
=
fake_dataset
.
map
(
fake_dataset
=
fake_dataset
.
map
(
lambda
val
:
cifar10_main
.
parse_record
(
val
,
False
))
lambda
val
:
cifar10_main
.
parse_record
(
val
,
False
,
tf
.
float32
))
image
,
label
=
fake_dataset
.
make_one_shot_iterator
().
get_next
()
image
,
label
=
fake_dataset
.
make_one_shot_iterator
().
get_next
()
self
.
assertAllEqual
(
label
.
shape
,
())
self
.
assertAllEqual
(
label
.
shape
,
())
...
...
official/resnet/imagenet_main.py
View file @
7b304676
...
@@ -129,7 +129,7 @@ def _parse_example_proto(example_serialized):
...
@@ -129,7 +129,7 @@ def _parse_example_proto(example_serialized):
return
features
[
'image/encoded'
],
label
,
bbox
return
features
[
'image/encoded'
],
label
,
bbox
def
parse_record
(
raw_record
,
is_training
):
def
parse_record
(
raw_record
,
is_training
,
dtype
):
"""Parses a record containing a training example of an image.
"""Parses a record containing a training example of an image.
The input record is parsed into a label and image, and the image is passed
The input record is parsed into a label and image, and the image is passed
...
@@ -139,6 +139,7 @@ def parse_record(raw_record, is_training):
...
@@ -139,6 +139,7 @@ def parse_record(raw_record, is_training):
raw_record: scalar Tensor tf.string containing a serialized
raw_record: scalar Tensor tf.string containing a serialized
Example protocol buffer.
Example protocol buffer.
is_training: A boolean denoting whether the input is for training.
is_training: A boolean denoting whether the input is for training.
dtype: data type to use for images/features.
Returns:
Returns:
Tuple with processed image tensor and one-hot-encoded label tensor.
Tuple with processed image tensor and one-hot-encoded label tensor.
...
@@ -152,11 +153,13 @@ def parse_record(raw_record, is_training):
...
@@ -152,11 +153,13 @@ def parse_record(raw_record, is_training):
output_width
=
_DEFAULT_IMAGE_SIZE
,
output_width
=
_DEFAULT_IMAGE_SIZE
,
num_channels
=
_NUM_CHANNELS
,
num_channels
=
_NUM_CHANNELS
,
is_training
=
is_training
)
is_training
=
is_training
)
image
=
tf
.
cast
(
image
,
dtype
)
return
image
,
label
return
image
,
label
def
input_fn
(
is_training
,
data_dir
,
batch_size
,
num_epochs
=
1
,
num_gpus
=
None
):
def
input_fn
(
is_training
,
data_dir
,
batch_size
,
num_epochs
=
1
,
num_gpus
=
None
,
dtype
=
tf
.
float32
):
"""Input function which provides batches for train or eval.
"""Input function which provides batches for train or eval.
Args:
Args:
...
@@ -165,6 +168,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None):
...
@@ -165,6 +168,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None):
batch_size: The number of samples per batch.
batch_size: The number of samples per batch.
num_epochs: The number of epochs to repeat the dataset.
num_epochs: The number of epochs to repeat the dataset.
num_gpus: The number of gpus used for training.
num_gpus: The number of gpus used for training.
dtype: Data type to use for images/features.
Returns:
Returns:
A dataset that can be used for iteration.
A dataset that can be used for iteration.
...
@@ -192,7 +196,8 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None):
...
@@ -192,7 +196,8 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, num_gpus=None):
parse_record_fn
=
parse_record
,
parse_record_fn
=
parse_record
,
num_epochs
=
num_epochs
,
num_epochs
=
num_epochs
,
num_gpus
=
num_gpus
,
num_gpus
=
num_gpus
,
examples_per_epoch
=
_NUM_IMAGES
[
'train'
]
if
is_training
else
None
examples_per_epoch
=
_NUM_IMAGES
[
'train'
]
if
is_training
else
None
,
dtype
=
dtype
)
)
...
...
official/resnet/resnet_run_loop.py
View file @
7b304676
...
@@ -45,7 +45,7 @@ from official.utils.misc import model_helpers
...
@@ -45,7 +45,7 @@ from official.utils.misc import model_helpers
################################################################################
################################################################################
def
process_record_dataset
(
dataset
,
is_training
,
batch_size
,
shuffle_buffer
,
def
process_record_dataset
(
dataset
,
is_training
,
batch_size
,
shuffle_buffer
,
parse_record_fn
,
num_epochs
=
1
,
num_gpus
=
None
,
parse_record_fn
,
num_epochs
=
1
,
num_gpus
=
None
,
examples_per_epoch
=
None
):
examples_per_epoch
=
None
,
dtype
=
tf
.
float32
):
"""Given a Dataset with raw records, return an iterator over the records.
"""Given a Dataset with raw records, return an iterator over the records.
Args:
Args:
...
@@ -60,6 +60,7 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
...
@@ -60,6 +60,7 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
num_epochs: The number of epochs to repeat the dataset.
num_epochs: The number of epochs to repeat the dataset.
num_gpus: The number of gpus used for training.
num_gpus: The number of gpus used for training.
examples_per_epoch: The number of examples in an epoch.
examples_per_epoch: The number of examples in an epoch.
dtype: Data type to use for images/features.
Returns:
Returns:
Dataset of (image, label) pairs ready for iteration.
Dataset of (image, label) pairs ready for iteration.
...
@@ -92,7 +93,7 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
...
@@ -92,7 +93,7 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
# batch_size is almost always much greater than the number of CPU cores.
# batch_size is almost always much greater than the number of CPU cores.
dataset
=
dataset
.
apply
(
dataset
=
dataset
.
apply
(
tf
.
contrib
.
data
.
map_and_batch
(
tf
.
contrib
.
data
.
map_and_batch
(
lambda
value
:
parse_record_fn
(
value
,
is_training
),
lambda
value
:
parse_record_fn
(
value
,
is_training
,
dtype
),
batch_size
=
batch_size
,
batch_size
=
batch_size
,
num_parallel_batches
=
1
,
num_parallel_batches
=
1
,
drop_remainder
=
False
))
drop_remainder
=
False
))
...
@@ -248,8 +249,8 @@ def resnet_model_fn(features, labels, mode, model_class,
...
@@ -248,8 +249,8 @@ def resnet_model_fn(features, labels, mode, model_class,
# Generate a summary node for the images
# Generate a summary node for the images
tf
.
summary
.
image
(
'images'
,
features
,
max_outputs
=
6
)
tf
.
summary
.
image
(
'images'
,
features
,
max_outputs
=
6
)
# TODO(tobyboyd): Add cast as part of input pipeline on cpu and remove.
# Checks that features/images have the same data type being used for calculations.
features
=
tf
.
cast
(
features
,
dtype
)
assert
features
.
dtype
==
dtype
model
=
model_class
(
resnet_size
,
data_format
,
resnet_version
=
resnet_version
,
model
=
model_class
(
resnet_size
,
data_format
,
resnet_version
=
resnet_version
,
dtype
=
dtype
)
dtype
=
dtype
)
...
@@ -454,14 +455,16 @@ def resnet_main(
...
@@ -454,14 +455,16 @@ def resnet_main(
batch_size
=
distribution_utils
.
per_device_batch_size
(
batch_size
=
distribution_utils
.
per_device_batch_size
(
flags_obj
.
batch_size
,
flags_core
.
get_num_gpus
(
flags_obj
)),
flags_obj
.
batch_size
,
flags_core
.
get_num_gpus
(
flags_obj
)),
num_epochs
=
num_epochs
,
num_epochs
=
num_epochs
,
num_gpus
=
flags_core
.
get_num_gpus
(
flags_obj
))
num_gpus
=
flags_core
.
get_num_gpus
(
flags_obj
),
dtype
=
flags_core
.
get_tf_dtype
(
flags_obj
))
def
input_fn_eval
():
def
input_fn_eval
():
return
input_function
(
return
input_function
(
is_training
=
False
,
data_dir
=
flags_obj
.
data_dir
,
is_training
=
False
,
data_dir
=
flags_obj
.
data_dir
,
batch_size
=
distribution_utils
.
per_device_batch_size
(
batch_size
=
distribution_utils
.
per_device_batch_size
(
flags_obj
.
batch_size
,
flags_core
.
get_num_gpus
(
flags_obj
)),
flags_obj
.
batch_size
,
flags_core
.
get_num_gpus
(
flags_obj
)),
num_epochs
=
1
)
num_epochs
=
1
,
dtype
=
flags_core
.
get_tf_dtype
(
flags_obj
))
if
flags_obj
.
eval_only
or
not
flags_obj
.
train_epochs
:
if
flags_obj
.
eval_only
or
not
flags_obj
.
train_epochs
:
# If --eval_only is set, perform a single loop with zero train epochs.
# If --eval_only is set, perform a single loop with zero train epochs.
...
@@ -533,7 +536,7 @@ def define_resnet_flags(resnet_size_choices=None):
...
@@ -533,7 +536,7 @@ def define_resnet_flags(resnet_size_choices=None):
'If not None initialize all the network except the final layer with '
'If not None initialize all the network except the final layer with '
'these values'
))
'these values'
))
flags
.
DEFINE_boolean
(
flags
.
DEFINE_boolean
(
name
=
"
eval_only
"
,
default
=
False
,
name
=
'
eval_only
'
,
default
=
False
,
help
=
flags_core
.
help_wrap
(
'Skip training and only perform evaluation on '
help
=
flags_core
.
help_wrap
(
'Skip training and only perform evaluation on '
'the latest checkpoint.'
))
'the latest checkpoint.'
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment