Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
5f0776a2
Unverified
Commit
5f0776a2
authored
Nov 07, 2017
by
Neal Wu
Committed by
GitHub
Nov 07, 2017
Browse files
Move dataset.map back to before dataset.shuffle in imagenet_main.py (#2731)
parent
21b48a85
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
8 deletions
+6
-8
official/resnet/imagenet_main.py
official/resnet/imagenet_main.py
+6
-8
No files found.
official/resnet/imagenet_main.py
View file @
5f0776a2
...
@@ -89,8 +89,8 @@ def filenames(is_training, data_dir):
...
@@ -89,8 +89,8 @@ def filenames(is_training, data_dir):
for
i
in
range
(
128
)]
for
i
in
range
(
128
)]
def
dataset
_parser
(
value
,
is_training
):
def
record
_parser
(
value
,
is_training
):
"""Parse an Image
n
et record from value."""
"""Parse an Image
N
et record from
`
value
`
."""
keys_to_features
=
{
keys_to_features
=
{
'image/encoded'
:
'image/encoded'
:
tf
.
FixedLenFeature
((),
tf
.
string
,
default_value
=
''
),
tf
.
FixedLenFeature
((),
tf
.
string
,
default_value
=
''
),
...
@@ -134,23 +134,21 @@ def dataset_parser(value, is_training):
...
@@ -134,23 +134,21 @@ def dataset_parser(value, is_training):
def
input_fn
(
is_training
,
data_dir
,
batch_size
,
num_epochs
=
1
):
def
input_fn
(
is_training
,
data_dir
,
batch_size
,
num_epochs
=
1
):
"""Input function which provides batches for train or eval."""
"""Input function which provides batches for train or eval."""
dataset
=
tf
.
data
.
Dataset
.
from_tensor_slices
(
dataset
=
tf
.
data
.
Dataset
.
from_tensor_slices
(
filenames
(
is_training
,
data_dir
))
filenames
(
is_training
,
data_dir
))
if
is_training
:
if
is_training
:
dataset
=
dataset
.
shuffle
(
buffer_size
=
_FILE_SHUFFLE_BUFFER
)
dataset
=
dataset
.
shuffle
(
buffer_size
=
_FILE_SHUFFLE_BUFFER
)
dataset
=
dataset
.
flat_map
(
tf
.
data
.
TFRecordDataset
)
dataset
=
dataset
.
flat_map
(
tf
.
data
.
TFRecordDataset
)
dataset
=
dataset
.
map
(
lambda
value
:
record_parser
(
value
,
is_training
),
num_parallel_calls
=
5
)
dataset
=
dataset
.
prefetch
(
batch_size
)
if
is_training
:
if
is_training
:
# When choosing shuffle buffer sizes, larger sizes result in better
# When choosing shuffle buffer sizes, larger sizes result in better
# randomness, while smaller sizes have better performance.
# randomness, while smaller sizes have better performance.
dataset
=
dataset
.
shuffle
(
buffer_size
=
_SHUFFLE_BUFFER
)
dataset
=
dataset
.
shuffle
(
buffer_size
=
_SHUFFLE_BUFFER
)
dataset
=
dataset
.
map
(
lambda
value
:
dataset_parser
(
value
,
is_training
),
num_parallel_calls
=
5
)
dataset
=
dataset
.
prefetch
(
batch_size
)
# We call repeat after shuffling, rather than before, to prevent separate
# We call repeat after shuffling, rather than before, to prevent separate
# epochs from blending together.
# epochs from blending together.
dataset
=
dataset
.
repeat
(
num_epochs
)
dataset
=
dataset
.
repeat
(
num_epochs
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment