ModelZoo / ResNet50_tensorflow

Commit f762825c (unverified)
Authored Oct 05, 2018 by Toby Boyd; committed by GitHub on Oct 05, 2018

Merge pull request #5443 from tfboyd/tf_data_autotune

Use AUTOTUNE, remove noop take, and comment fixes
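A note on the "noop take" in the title: tf.data datasets are immutable, and Dataset.take returns a new, truncated dataset rather than modifying the one it is called on, so the removed call, whose return value was never assigned, had no effect. A minimal illustration (not part of this commit; the toy dataset below is a placeholder):

import tensorflow as tf

ds = tf.data.Dataset.range(10)
ds.take(3)       # No-op: the truncated dataset is created but never assigned.
ds = ds.take(3)  # Effective: rebind ds to the truncated dataset.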
Parents: 6801ea36, fe3746e6
Showing 1 changed file with 10 additions and 20 deletions:

official/resnet/resnet_run_loop.py (+10, -20)

--- a/official/resnet/resnet_run_loop.py
+++ b/official/resnet/resnet_run_loop.py
@@ -68,36 +68,26 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
     Dataset of (image, label) pairs ready for iteration.
   """
 
-  # We prefetch a batch at a time, This can help smooth out the time taken to
-  # load input files as we go through shuffling and processing.
+  # Sets tf.data to AUTOTUNE, e.g. num_parallel_batches in map_and_batch.
+  options = tf.data.Options()
+  options.experimental_autotune = True
+  dataset = dataset.with_options(options)
+
+  # Prefetches a batch at a time to smooth out the time taken to load input
+  # files for shuffling and processing.
   dataset = dataset.prefetch(buffer_size=batch_size)
   if is_training:
-    # Shuffle the records. Note that we shuffle before repeating to ensure
-    # that the shuffling respects epoch boundaries.
+    # Shuffles records before repeating to respect epoch boundaries.
     dataset = dataset.shuffle(buffer_size=shuffle_buffer)
 
-  # If we are training over multiple epochs before evaluating, repeat the
-  # dataset for the appropriate number of epochs.
+  # Repeats the dataset for the number of epochs to train.
   dataset = dataset.repeat(num_epochs)
 
-  if is_training and num_gpus and examples_per_epoch:
-    total_examples = num_epochs * examples_per_epoch
-    # Force the number of batches to be divisible by the number of devices.
-    # This prevents some devices from receiving batches while others do not,
-    # which can lead to a lockup. This case will soon be handled directly by
-    # distribution strategies, at which point this .take() operation will no
-    # longer be needed.
-    total_batches = total_examples // batch_size // num_gpus * num_gpus
-    dataset.take(total_batches * batch_size)
-
-  # Parse the raw records into images and labels. Testing has shown that setting
-  # num_parallel_batches > 1 produces no improvement in throughput, since
-  # batch_size is almost always much greater than the number of CPU cores.
+  # Parses the raw records into images and labels.
   dataset = dataset.apply(
       tf.contrib.data.map_and_batch(
           lambda value: parse_record_fn(value, is_training, dtype),
           batch_size=batch_size,
-          num_parallel_batches=1,
           drop_remainder=False))
 
   # Operations between the final prefetch and the get_next call to the iterator
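Taken together, the hunk replaces hard-coded input-pipeline parallelism with tf.data autotuning and drops the ineffective take. Below is a minimal, self-contained sketch of the resulting pipeline, assuming a TensorFlow 1.12-era install where tf.data.Options.experimental_autotune and tf.contrib.data.map_and_batch are available; build_pipeline, the stub parse_record_fn, and the toy dataset are illustrative names, not code from the repository:

import tensorflow as tf


def parse_record_fn(value, is_training, dtype):
  # Stub parser: the real function decodes a serialized record into an
  # (image, label) pair.
  return tf.cast(value, dtype), tf.zeros([], tf.int32)


def build_pipeline(dataset, is_training, batch_size, shuffle_buffer,
                   num_epochs, dtype=tf.float32):
  # Let tf.data pick tunable parameters (e.g. num_parallel_batches) itself.
  options = tf.data.Options()
  options.experimental_autotune = True
  dataset = dataset.with_options(options)

  # Prefetch a batch at a time to smooth out input-file loading.
  dataset = dataset.prefetch(buffer_size=batch_size)
  if is_training:
    # Shuffle before repeating so shuffling respects epoch boundaries.
    dataset = dataset.shuffle(buffer_size=shuffle_buffer)

  dataset = dataset.repeat(num_epochs)

  # Fused map-and-batch; no explicit num_parallel_batches is passed.
  dataset = dataset.apply(
      tf.contrib.data.map_and_batch(
          lambda value: parse_record_fn(value, is_training, dtype),
          batch_size=batch_size,
          drop_remainder=False))
  return dataset


# Example usage with a toy in-memory dataset standing in for the record files.
pipeline = build_pipeline(tf.data.Dataset.range(1000), is_training=True,
                          batch_size=32, shuffle_buffer=100, num_epochs=2)

Dropping the explicit num_parallel_batches=1 argument lets the autotuner choose the map/batch parallelism, which is what the commit's new comment refers to when it says AUTOTUNE covers num_parallel_batches in map_and_batch.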