Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
c6479e77
Commit
c6479e77
authored
Nov 06, 2017
by
Neal Wu
Browse files
Ensure that shuffle occurs before map
parent
6e52c271
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
10 additions
and
8 deletions
+10
-8
official/resnet/cifar10_main.py
official/resnet/cifar10_main.py
+3
-4
official/resnet/imagenet_main.py
official/resnet/imagenet_main.py
+4
-3
official/wide_deep/wide_deep.py
official/wide_deep/wide_deep.py
+3
-1
No files found.
official/resnet/cifar10_main.py
View file @
c6479e77
...
@@ -71,8 +71,6 @@ _NUM_IMAGES = {
...
@@ -71,8 +71,6 @@ _NUM_IMAGES = {
'validation'
:
10000
,
'validation'
:
10000
,
}
}
_SHUFFLE_BUFFER
=
20000
def
record_dataset
(
filenames
):
def
record_dataset
(
filenames
):
"""Returns an input pipeline Dataset from `filenames`."""
"""Returns an input pipeline Dataset from `filenames`."""
...
@@ -158,8 +156,9 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
...
@@ -158,8 +156,9 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
if
is_training
:
if
is_training
:
# When choosing shuffle buffer sizes, larger sizes result in better
# When choosing shuffle buffer sizes, larger sizes result in better
# randomness, while smaller sizes have better performance.
# randomness, while smaller sizes have better performance. Because CIFAR-10
dataset
=
dataset
.
shuffle
(
buffer_size
=
_SHUFFLE_BUFFER
)
# is a relatively small dataset, we choose to shuffle the full epoch.
dataset
=
dataset
.
shuffle
(
buffer_size
=
_NUM_IMAGES
[
'train'
])
dataset
=
dataset
.
map
(
parse_record
)
dataset
=
dataset
.
map
(
parse_record
)
dataset
=
dataset
.
map
(
dataset
=
dataset
.
map
(
...
...
official/resnet/imagenet_main.py
View file @
c6479e77
...
@@ -142,14 +142,15 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
...
@@ -142,14 +142,15 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
dataset
=
dataset
.
flat_map
(
tf
.
data
.
TFRecordDataset
)
dataset
=
dataset
.
flat_map
(
tf
.
data
.
TFRecordDataset
)
dataset
=
dataset
.
map
(
lambda
value
:
dataset_parser
(
value
,
is_training
),
num_parallel_calls
=
5
).
prefetch
(
batch_size
)
if
is_training
:
if
is_training
:
# When choosing shuffle buffer sizes, larger sizes result in better
# When choosing shuffle buffer sizes, larger sizes result in better
# randomness, while smaller sizes have better performance.
# randomness, while smaller sizes have better performance.
dataset
=
dataset
.
shuffle
(
buffer_size
=
_SHUFFLE_BUFFER
)
dataset
=
dataset
.
shuffle
(
buffer_size
=
_SHUFFLE_BUFFER
)
dataset
=
dataset
.
map
(
lambda
value
:
dataset_parser
(
value
,
is_training
),
num_parallel_calls
=
5
)
dataset
=
dataset
.
prefetch
(
batch_size
)
# We call repeat after shuffling, rather than before, to prevent separate
# We call repeat after shuffling, rather than before, to prevent separate
# epochs from blending together.
# epochs from blending together.
dataset
=
dataset
.
repeat
(
num_epochs
)
dataset
=
dataset
.
repeat
(
num_epochs
)
...
...
official/wide_deep/wide_deep.py
View file @
c6479e77
...
@@ -179,11 +179,12 @@ def input_fn(data_file, num_epochs, shuffle, batch_size):
...
@@ -179,11 +179,12 @@ def input_fn(data_file, num_epochs, shuffle, batch_size):
# Extract lines from input files using the Dataset API.
# Extract lines from input files using the Dataset API.
dataset
=
tf
.
data
.
TextLineDataset
(
data_file
)
dataset
=
tf
.
data
.
TextLineDataset
(
data_file
)
dataset
=
dataset
.
map
(
parse_csv
,
num_parallel_calls
=
5
)
if
shuffle
:
if
shuffle
:
dataset
=
dataset
.
shuffle
(
buffer_size
=
_SHUFFLE_BUFFER
)
dataset
=
dataset
.
shuffle
(
buffer_size
=
_SHUFFLE_BUFFER
)
dataset
=
dataset
.
map
(
parse_csv
,
num_parallel_calls
=
5
)
# We call repeat after shuffling, rather than before, to prevent separate
# We call repeat after shuffling, rather than before, to prevent separate
# epochs from blending together.
# epochs from blending together.
dataset
=
dataset
.
repeat
(
num_epochs
)
dataset
=
dataset
.
repeat
(
num_epochs
)
...
@@ -193,6 +194,7 @@ def input_fn(data_file, num_epochs, shuffle, batch_size):
...
@@ -193,6 +194,7 @@ def input_fn(data_file, num_epochs, shuffle, batch_size):
features
,
labels
=
iterator
.
get_next
()
features
,
labels
=
iterator
.
get_next
()
return
features
,
labels
return
features
,
labels
def
main
(
unused_argv
):
def
main
(
unused_argv
):
# Clean up the model directory if present
# Clean up the model directory if present
shutil
.
rmtree
(
FLAGS
.
model_dir
,
ignore_errors
=
True
)
shutil
.
rmtree
(
FLAGS
.
model_dir
,
ignore_errors
=
True
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment