Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
3ae33b4d
Unverified
Commit
3ae33b4d
authored
Jun 06, 2018
by
Taylor Robie
Committed by
GitHub
Jun 06, 2018
Browse files
Fix/resnet take (#4473)
* add .take() to dataset pipeline * delint * address PR comments
parent
441c9bca
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
37 additions
and
8 deletions
+37
-8
official/resnet/cifar10_main.py
official/resnet/cifar10_main.py
+10
-3
official/resnet/imagenet_main.py
official/resnet/imagenet_main.py
+10
-3
official/resnet/resnet_run_loop.py
official/resnet/resnet_run_loop.py
+17
-2
No files found.
official/resnet/cifar10_main.py
View file @
3ae33b4d
...
...
@@ -107,7 +107,7 @@ def preprocess_image(image, is_training):
return
image
def
input_fn
(
is_training
,
data_dir
,
batch_size
,
num_epochs
=
1
):
def
input_fn
(
is_training
,
data_dir
,
batch_size
,
num_epochs
=
1
,
num_gpus
=
None
):
"""Input_fn using the tf.data input pipeline for CIFAR-10 dataset.
Args:
...
...
@@ -115,6 +115,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
data_dir: The directory containing the input data.
batch_size: The number of samples per batch.
num_epochs: The number of epochs to repeat the dataset.
num_gpus: The number of gpus used for training.
Returns:
A dataset that can be used for iteration.
...
...
@@ -123,8 +124,14 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
dataset
=
tf
.
data
.
FixedLengthRecordDataset
(
filenames
,
_RECORD_BYTES
)
return
resnet_run_loop
.
process_record_dataset
(
dataset
,
is_training
,
batch_size
,
_NUM_IMAGES
[
'train'
],
parse_record
,
num_epochs
,
dataset
=
dataset
,
is_training
=
is_training
,
batch_size
=
batch_size
,
shuffle_buffer
=
_NUM_IMAGES
[
'train'
],
parse_record_fn
=
parse_record
,
num_epochs
=
num_epochs
,
num_gpus
=
num_gpus
,
examples_per_epoch
=
_NUM_IMAGES
[
'train'
]
if
is_training
else
None
)
...
...
official/resnet/imagenet_main.py
View file @
3ae33b4d
...
...
@@ -156,7 +156,7 @@ def parse_record(raw_record, is_training):
return
image
,
label
def
input_fn
(
is_training
,
data_dir
,
batch_size
,
num_epochs
=
1
):
def
input_fn
(
is_training
,
data_dir
,
batch_size
,
num_epochs
=
1
,
num_gpus
=
None
):
"""Input function which provides batches for train or eval.
Args:
...
...
@@ -164,6 +164,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
data_dir: The directory containing the input data.
batch_size: The number of samples per batch.
num_epochs: The number of epochs to repeat the dataset.
num_gpus: The number of gpus used for training.
Returns:
A dataset that can be used for iteration.
...
...
@@ -184,8 +185,14 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
tf
.
data
.
TFRecordDataset
,
cycle_length
=
10
))
return
resnet_run_loop
.
process_record_dataset
(
dataset
,
is_training
,
batch_size
,
_SHUFFLE_BUFFER
,
parse_record
,
num_epochs
dataset
=
dataset
,
is_training
=
is_training
,
batch_size
=
batch_size
,
shuffle_buffer
=
_SHUFFLE_BUFFER
,
parse_record_fn
=
parse_record
,
num_epochs
=
num_epochs
,
num_gpus
=
num_gpus
,
examples_per_epoch
=
_NUM_IMAGES
[
'train'
]
if
is_training
else
None
)
...
...
official/resnet/resnet_run_loop.py
View file @
3ae33b4d
...
...
@@ -42,7 +42,8 @@ from official.utils.misc import model_helpers
# Functions for input processing.
################################################################################
def
process_record_dataset
(
dataset
,
is_training
,
batch_size
,
shuffle_buffer
,
parse_record_fn
,
num_epochs
=
1
):
parse_record_fn
,
num_epochs
=
1
,
num_gpus
=
None
,
examples_per_epoch
=
None
):
"""Given a Dataset with raw records, return an iterator over the records.
Args:
...
...
@@ -55,6 +56,8 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
parse_record_fn: A function that takes a raw record and returns the
corresponding (image, label) pair.
num_epochs: The number of epochs to repeat the dataset.
num_gpus: The number of gpus used for training.
examples_per_epoch: The number of examples in an epoch.
Returns:
Dataset of (image, label) pairs ready for iteration.
...
...
@@ -72,6 +75,16 @@ def process_record_dataset(dataset, is_training, batch_size, shuffle_buffer,
# dataset for the appropriate number of epochs.
dataset
=
dataset
.
repeat
(
num_epochs
)
if is_training and num_gpus and examples_per_epoch:
  total_examples = num_epochs * examples_per_epoch
  # Force the number of batches to be divisible by the number of devices.
  # This prevents some devices from receiving batches while others do not,
  # which can lead to a lockup. This case will soon be handled directly by
  # distribution strategies, at which point this .take() operation will no
  # longer be needed.
  total_batches = total_examples // batch_size // num_gpus * num_gpus
  # Dataset transformations return a new Dataset instead of modifying the
  # receiver, so the result has to be rebound; calling .take() and dropping
  # its return value would leave the pipeline untruncated.
  dataset = dataset.take(total_batches * batch_size)
# Parse the raw records into images and labels. Testing has shown that setting
# num_parallel_batches > 1 produces no improvement in throughput, since
# batch_size is almost always much greater than the number of CPU cores.
...
...
@@ -411,7 +424,8 @@ def resnet_main(
is_training
=
True
,
data_dir
=
flags_obj
.
data_dir
,
batch_size
=
per_device_batch_size
(
flags_obj
.
batch_size
,
flags_core
.
get_num_gpus
(
flags_obj
)),
num_epochs
=
flags_obj
.
epochs_between_evals
)
num_epochs
=
flags_obj
.
epochs_between_evals
,
num_gpus
=
flags_core
.
get_num_gpus
(
flags_obj
))
def
input_fn_eval
():
return
input_function
(
...
...
@@ -419,6 +433,7 @@ def resnet_main(
batch_size
=
per_device_batch_size
(
flags_obj
.
batch_size
,
flags_core
.
get_num_gpus
(
flags_obj
)),
num_epochs
=
1
)
total_training_cycle
=
(
flags_obj
.
train_epochs
//
flags_obj
.
epochs_between_evals
)
for
cycle_index
in
range
(
total_training_cycle
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment