ModelZoo / ResNet50_tensorflow · Commits

Commit ca509f79, authored Feb 18, 2021 by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 358289482

Parent: a5e7e2ce

Showing 1 changed file with 107 additions and 59 deletions.

official/core/input_reader.py  (+107 −59)
...
@@ -14,7 +14,7 @@
 """A common dataset reader."""
 import random

-from typing import Any, Callable, Optional
+from typing import Any, Callable, List, Optional

 from absl import logging
 import tensorflow as tf
...
@@ -27,6 +27,13 @@ def _get_random_integer():
   return random.randint(0, (1 << 31) - 1)


+def _maybe_map_fn(dataset: tf.data.Dataset,
+                  fn: Optional[Callable[..., Any]] = None) -> tf.data.Dataset:
+  """Calls dataset.map if a valid function is passed in."""
+  return dataset if fn is None else dataset.map(
+      fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+
 class InputReader:
   """Input reader that returns a tf.data.Dataset instance."""
...
@@ -74,38 +81,7 @@ class InputReader:
     self._tfds_builder = None
     self._matched_files = []
     if params.input_path:
-      # Read dataset from files.
-      usage = ('`input_path` should be either (1) a str indicating a file '
-               'path/pattern, or (2) a str indicating multiple file '
-               'paths/patterns separated by comma (e.g "a, b, c" or no spaces '
-               '"a,b,c", or (3) a list of str, each of which is a file '
-               'path/pattern or multiple file paths/patterns separated by '
-               'comma, but got: %s')
-      if isinstance(params.input_path, str):
-        input_path_list = [params.input_path]
-      elif isinstance(params.input_path, (list, tuple)):
-        if any(not isinstance(x, str) for x in params.input_path):
-          raise ValueError(usage % params.input_path)
-        input_path_list = params.input_path
-      else:
-        raise ValueError(usage % params.input_path)
-
-      for input_path in input_path_list:
-        input_patterns = input_path.strip().split(',')
-        for input_pattern in input_patterns:
-          input_pattern = input_pattern.strip()
-          if not input_pattern:
-            continue
-          if '*' in input_pattern or '?' in input_pattern:
-            tmp_matched_files = tf.io.gfile.glob(input_pattern)
-            if not tmp_matched_files:
-              raise ValueError('%s does not match any files.' % input_pattern)
-            self._matched_files.extend(tmp_matched_files)
-          else:
-            self._matched_files.append(input_pattern)
-
-      if not self._matched_files:
-        raise ValueError('%s does not match any files.' % params.input_path)
+      self._matched_files = self._match_files(params.input_path)
     else:
       # Read dataset from TFDS.
       if not params.tfds_split:
...
@@ -148,15 +124,57 @@ class InputReader:
     self._enable_round_robin_tf_data_service = params.get(
         'enable_round_robin_tf_data_service', False)

+  def _match_files(self, input_path: str) -> List[str]:
+    """Matches files from an input_path."""
+    matched_files = []
+    # Read dataset from files.
+    usage = ('`input_path` should be either (1) a str indicating a file '
+             'path/pattern, or (2) a str indicating multiple file '
+             'paths/patterns separated by comma (e.g "a, b, c" or no spaces '
+             '"a,b,c", or (3) a list of str, each of which is a file '
+             'path/pattern or multiple file paths/patterns separated by '
+             'comma, but got: %s')
+    if isinstance(input_path, str):
+      input_path_list = [input_path]
+    elif isinstance(input_path, (list, tuple)):
+      if any(not isinstance(x, str) for x in input_path):
+        raise ValueError(usage % input_path)
+      input_path_list = input_path
+    else:
+      raise ValueError(usage % input_path)
+
+    for input_path in input_path_list:
+      input_patterns = input_path.strip().split(',')
+      for input_pattern in input_patterns:
+        input_pattern = input_pattern.strip()
+        if not input_pattern:
+          continue
+        if '*' in input_pattern or '?' in input_pattern:
+          tmp_matched_files = tf.io.gfile.glob(input_pattern)
+          if not tmp_matched_files:
+            raise ValueError('%s does not match any files.' % input_pattern)
+          matched_files.extend(tmp_matched_files)
+        else:
+          matched_files.append(input_pattern)
+
+    if not matched_files:
+      raise ValueError('%s does not match any files.' % input_path)
+
+    return matched_files
+
-  def _shard_files_then_read(self,
-                             input_context: Optional[
-                                 tf.distribute.InputContext] = None):
+  def _shard_files_then_read(
+      self,
+      matched_files: List[str],
+      dataset_fn,
+      input_context: Optional[tf.distribute.InputContext] = None
+  ) -> tf.data.Dataset:
     """Shards the data files and then sends a split to every worker to read."""
-    dataset = tf.data.Dataset.from_tensor_slices(self._matched_files)
+    dataset = tf.data.Dataset.from_tensor_slices(matched_files)

     # Shuffle and repeat at file level.
     if self._is_training:
       dataset = dataset.shuffle(
-          len(self._matched_files),
+          len(matched_files),
           seed=self._seed,
           reshuffle_each_iteration=True)
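The file-matching rules are unchanged, only relocated: `_match_files` accepts a single path, a comma-separated string of paths, or a list of either, and expands any pattern containing `*` or `?` via `tf.io.gfile.glob`. A hedged illustration with made-up paths, assuming `reader` is an `InputReader` instance:

```python
# All three accepted `input_path` forms; every path here is hypothetical.
reader._match_files('/data/train.tfrecord')                # (1) single path
reader._match_files('/data/a.tfrecord, /data/b.tfrecord')  # (2) comma-separated
reader._match_files(['/data/train-*.tfrecord',             # (3) list; globs
                     '/data/extra.tfrecord'])              #     are expanded
```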
...
@@ -171,7 +189,7 @@ class InputReader:
     dataset = dataset.repeat()

     dataset = dataset.interleave(
-        map_func=self._dataset_fn,
+        map_func=dataset_fn,
         cycle_length=self._cycle_length,
         block_length=self._block_length,
         num_parallel_calls=(self._cycle_length if self._cycle_length else
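`_shard_files_then_read` now receives `matched_files` and `dataset_fn` explicitly instead of reading them off `self`. The underlying tf.data pattern, sketched independently of `InputReader` (file names are placeholders):

```python
import tensorflow as tf

# Shard by file: each worker keeps only its slice of the file list, then
# interleaves record reads across the files in that slice.
files = tf.data.Dataset.from_tensor_slices(
    ['train-00000.tfrecord', 'train-00001.tfrecord'])  # placeholder shards
files = files.shard(num_shards=2, index=0)  # this worker's split
dataset = files.interleave(
    tf.data.TFRecordDataset,
    cycle_length=2,
    num_parallel_calls=tf.data.experimental.AUTOTUNE)
```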
...
@@ -180,9 +198,13 @@ class InputReader:
     return dataset

-  def _read_files_then_shard(self,
-                             input_context: Optional[
-                                 tf.distribute.InputContext] = None):
+  def _read_files_then_shard(
+      self,
+      matched_files: List[str],
+      dataset_fn,
+      input_context: Optional[tf.distribute.InputContext] = None
+  ) -> tf.data.Dataset:
     """Sends all data files to every worker and then shards by data."""
-    dataset = self._dataset_fn(self._matched_files)
+    dataset = dataset_fn(matched_files)

     # When `input_file` is a path to a single file or the number of files is
     # less than the number of input pipelines, disable auto sharding
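`_read_files_then_shard` takes the opposite approach: every worker opens all files and the data itself is sharded downstream, which is why the method disables tf.data auto-sharding, per the comment above. A sketch of that option under those assumptions, not a copy of the method body:

```python
import tensorflow as tf

# Shard by element: all workers read the same files, so automatic
# file-based sharding must be turned off.
dataset = tf.data.TFRecordDataset(['train.tfrecord'])  # placeholder file
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = (
    tf.data.experimental.AutoShardPolicy.OFF)
dataset = dataset.with_options(options)
```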
...
@@ -238,26 +260,35 @@ class InputReader:
       raise ValueError('tfds_info is not available, because the dataset '
                        'is not loaded from tfds.')

-  def read(self,
-           input_context: Optional[tf.distribute.InputContext] = None
-          ) -> tf.data.Dataset:
-    """Generates a tf.data.Dataset object."""
-    if self._tfds_builder:
+  def _read_decode_and_parse_dataset(
+      self,
+      matched_files: List[str],
+      dataset_fn,
+      batch_size: int,
+      input_context: Optional[tf.distribute.InputContext] = None,
+      tfds_builder: bool = False) -> tf.data.Dataset:
+    """Returns a tf.data.Dataset object after reading, decoding, and parsing."""
+    if tfds_builder:
       dataset = self._read_tfds(input_context)
-    elif len(self._matched_files) > 1:
-      if input_context and (len(self._matched_files) <
+    elif len(matched_files) > 1:
+      if input_context and (len(matched_files) <
                             input_context.num_input_pipelines):
         logging.warn(
             'The number of files %d is less than the number of input pipelines '
             '%d. We will send all input files to every worker. '
             'Please consider sharding your data into more files.',
-            len(self._matched_files), input_context.num_input_pipelines)
-        dataset = self._read_files_then_shard(input_context)
+            len(matched_files), input_context.num_input_pipelines)
+        dataset = self._read_files_then_shard(matched_files, dataset_fn,
+                                              input_context)
       else:
-        dataset = self._shard_files_then_read(input_context)
-    elif len(self._matched_files) == 1:
-      dataset = self._read_files_then_shard(input_context)
+        dataset = self._shard_files_then_read(matched_files, dataset_fn,
+                                              input_context)
+    elif len(matched_files) == 1:
+      dataset = self._read_files_then_shard(matched_files, dataset_fn,
+                                            input_context)
     else:
       raise ValueError('It is unexpected that `tfds_builder` is None and '
                        'there is also no `matched_files`.')
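The old `read()` body becomes `_read_decode_and_parse_dataset`, parameterized on the files, the dataset function, the batch size, and a TFDS flag. A hypothetical distillation of its dispatch logic (the function name and return strings are ours):

```python
from typing import Optional
import tensorflow as tf

def pick_strategy(num_files: int,
                  input_context: Optional[tf.distribute.InputContext],
                  tfds_builder: bool) -> str:
  # Hypothetical summary of the branch structure in the hunk above.
  if tfds_builder:
    return 'read_tfds'
  if num_files > 1:
    if input_context and num_files < input_context.num_input_pipelines:
      return 'read_files_then_shard'  # too few files to give one per pipeline
    return 'shard_files_then_read'
  if num_files == 1:
    return 'read_files_then_shard'
  raise ValueError('No matched files and no TFDS builder.')
```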
...
@@ -268,25 +299,28 @@ class InputReader:
     if self._is_training:
       dataset = dataset.shuffle(self._shuffle_buffer_size)

-    def maybe_map_fn(dataset, fn):
-      return dataset if fn is None else dataset.map(
-          fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
-
-    dataset = maybe_map_fn(dataset, self._decoder_fn)
+    dataset = _maybe_map_fn(dataset, self._decoder_fn)
     if self._sample_fn is not None:
       dataset = dataset.apply(self._sample_fn)
-    dataset = maybe_map_fn(dataset, self._parser_fn)
+    dataset = _maybe_map_fn(dataset, self._parser_fn)

     if self._transform_and_batch_fn is not None:
       dataset = self._transform_and_batch_fn(dataset, input_context)
     else:
       per_replica_batch_size = input_context.get_per_replica_batch_size(
-          self._global_batch_size) if input_context else self._global_batch_size
+          batch_size) if input_context else batch_size
       dataset = dataset.batch(
           per_replica_batch_size, drop_remainder=self._drop_remainder)

-    dataset = maybe_map_fn(dataset, self._postprocess_fn)
+    return dataset

+  def _maybe_apply_data_service(
+      self,
+      dataset: tf.data.Dataset,
+      input_context: Optional[tf.distribute.InputContext] = None
+  ) -> tf.data.Dataset:
+    """Potentially distributes a dataset."""
     if self._enable_tf_data_service and input_context:
       if self._enable_round_robin_tf_data_service:
         replicas_per_input_pipeline = input_context.num_replicas_in_sync // (
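Batching now uses the `batch_size` argument rather than `self._global_batch_size`, but the arithmetic is unchanged: under `tf.distribute`, each replica batches its share of the global batch. A one-line illustration with invented numbers:

```python
# input_context.get_per_replica_batch_size(global) divides the global batch
# evenly across replicas (and raises if it does not divide evenly).
global_batch_size = 256
num_replicas_in_sync = 8
per_replica_batch_size = global_batch_size // num_replicas_in_sync  # 32
```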
...
@@ -316,6 +350,20 @@ class InputReader:
             processing_mode='parallel_epochs',
             service=self._tf_data_service_address,
             job_name=self._tf_data_service_job_name))
+    return dataset

+  def read(self,
+           input_context: Optional[tf.distribute.InputContext] = None
+          ) -> tf.data.Dataset:
+    """Generates a tf.data.Dataset object."""
+    dataset = self._read_decode_and_parse_dataset(self._matched_files,
+                                                  self._dataset_fn,
+                                                  self._global_batch_size,
+                                                  input_context,
+                                                  self._tfds_builder)
+    dataset = _maybe_map_fn(dataset, self._postprocess_fn)
+    dataset = self._maybe_apply_data_service(dataset, input_context)
+
     if self._deterministic is not None:
       options = tf.data.Options()
...
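After the refactor, the public `read()` is a thin composition of the pieces above. A hypothetical call site: the `params` object, decoder, and parser are placeholders of ours, and the constructor keywords are assumptions based on this file rather than anything shown in the diff:

```python
# Hypothetical usage; decoder/parser names and params are placeholders.
reader = InputReader(params,
                     dataset_fn=tf.data.TFRecordDataset,
                     decoder_fn=my_decoder,
                     parser_fn=my_parser)
dataset = reader.read()  # batched tf.data.Dataset:
                         # read/decode/parse -> postprocess -> data service
```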