ModelZoo / ResNet50_tensorflow · Commit 6a2de9bb

Authored Jul 21, 2020 by Bruce Fontaine; committed by A. Unique TensorFlower, Jul 21, 2020

Fix NCF input pipeline to avoid reading the same file multiple times in one epoch.

PiperOrigin-RevId: 322415899
Parent: f97e0231
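What changed: the old pipeline fed the shuffled file list straight into a single interleave of TFRecord readers. The new pipeline splits the file list into NUM_SHARDS = 16 disjoint shards inside a make_dataset helper and interleaves one reader per shard. Because Dataset.shard(n, i) keeps every n-th file starting at offset i, the shards partition the file list, so each file is read exactly once per epoch. The sketch below is not part of the commit; it demonstrates the same shard-and-interleave pattern on a toy integer dataset that stands in for the file list, so it runs without any TFRecord files.

import tensorflow as tf  # the repo uses tensorflow.compat.v2 as tf

NUM_SHARDS = 4
# Stand-in for tf.data.Dataset.list_files(input_file_pattern, shuffle=True).
files = tf.data.Dataset.range(20)

def make_dataset(files_dataset, shard_index):
  # shard(n, i) keeps elements whose position mod n equals i, so the
  # NUM_SHARDS shard datasets are disjoint and together cover every "file".
  return files_dataset.shard(NUM_SHARDS, shard_index)

dataset = tf.data.Dataset.range(NUM_SHARDS)
dataset = dataset.interleave(
    lambda shard_index: make_dataset(files, shard_index),
    cycle_length=NUM_SHARDS,
    num_parallel_calls=tf.data.experimental.AUTOTUNE)

seen = sorted(int(x) for x in dataset)
assert seen == list(range(20))  # every "file" appears exactly once

The assert documents the property named in the commit message: the shards form a partition, so one pass over the interleaved dataset touches each file exactly once.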
Showing 1 changed file with 14 additions and 25 deletions:

official/recommendation/ncf_input_pipeline.py (+14 −25)
@@ -25,10 +25,8 @@ import tensorflow.compat.v2 as tf
 # pylint: enable=g-bad-import-order

 from official.recommendation import constants as rconst
-from official.recommendation import movielens
 from official.recommendation import data_pipeline
+from official.recommendation import movielens
+
+NUM_SHARDS = 16

 def create_dataset_from_tf_record_files(input_file_pattern,
@@ -36,17 +34,15 @@ def create_dataset_from_tf_record_files(input_file_pattern,
                                         batch_size,
                                         is_training=True):
   """Creates dataset from (tf)records files for training/evaluation."""
-  if pre_batch_size != batch_size:
-    raise ValueError("Pre-batch ({}) size is not equal to batch "
-                     "size ({})".format(pre_batch_size, batch_size))
-
-  files = tf.data.Dataset.list_files(input_file_pattern, shuffle=is_training)
-
-  dataset = files.interleave(
-      tf.data.TFRecordDataset,
-      cycle_length=16,
-      num_parallel_calls=tf.data.experimental.AUTOTUNE)
-  decode_fn = functools.partial(
-      data_pipeline.DatasetManager.deserialize,
+  files = tf.data.Dataset.list_files(input_file_pattern, shuffle=is_training)
+
+  def make_dataset(files_dataset, shard_index):
+    """Returns dataset for sharded tf record files."""
+    if pre_batch_size != batch_size:
+      raise ValueError("Pre-batch ({}) size is not equal to batch "
+                       "size ({})".format(pre_batch_size, batch_size))
+    files_dataset = files_dataset.shard(NUM_SHARDS, shard_index)
+    dataset = files_dataset.interleave(
+        tf.data.TFRecordDataset,
+        num_parallel_calls=tf.data.experimental.AUTOTUNE)
+    decode_fn = functools.partial(
+        data_pipeline.DatasetManager.deserialize,
@@ -54,14 +50,7 @@ def create_dataset_from_tf_record_files(input_file_pattern,
-      is_training=is_training)
-  dataset = dataset.map(
-      decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+        is_training=is_training)
+    dataset = dataset.map(
+        decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+    return dataset
+
+  dataset = tf.data.Dataset.range(NUM_SHARDS)
+  map_fn = functools.partial(make_dataset, files)
+  dataset = dataset.interleave(
+      map_fn,
+      cycle_length=NUM_SHARDS,
+      num_parallel_calls=tf.data.experimental.AUTOTUNE)
   dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
   return dataset
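For reference, a hypothetical call to the updated function; the file pattern and batch sizes below are placeholders, not values from the repo. Note that pre_batch_size, a parameter elided from the hunks above, must equal batch_size, otherwise the ValueError shown in the diff is raised.

from official.recommendation import ncf_input_pipeline

# Placeholder glob; anything accepted by tf.data.Dataset.list_files works.
train_dataset = ncf_input_pipeline.create_dataset_from_tf_record_files(
    input_file_pattern="/tmp/ncf/training_data/*.tfrecord",
    pre_batch_size=1024,  # must equal batch_size
    batch_size=1024,
    is_training=True)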