ModelZoo / ResNet50_tensorflow · Commits · 5a3af75c

Commit 5a3af75c, authored Jun 10, 2020 by Chen Chen, committed by A. Unique TensorFlower on Jun 10, 2020.

Support to read a dataset from TFDS.

PiperOrigin-RevId: 315774221
parent 8a1dbbad

Changes: 2 changed files with 81 additions and 4 deletions (+81 -4).
official/core/input_reader.py: +56 -3
official/modeling/hyperparams/config_definitions.py: +25 -1
official/core/input_reader.py (view file @ 5a3af75c)
@@ -18,6 +18,7 @@

from typing import Any, Callable, List, Optional
import tensorflow as tf
import tensorflow_datasets as tfds
from official.modeling.hyperparams import config_definitions as cfg
@@ -53,11 +54,15 @@ class InputReader:

      postprocess_fn: An optional `callable` that processes batched tensors. It
        will be executed after batching.
    """
    if params.input_path and params.tfds_name:
      raise ValueError(
          'At most one of `input_path` and `tfds_name` can be '
          'specified, but got %s and %s.' %
          (params.input_path, params.tfds_name))
    self._shards = shards
    self._tfds_builder = None
    if self._shards:
      self._num_files = len(self._shards)
    elif not params.tfds_name:
      self._input_patterns = params.input_path.strip().split(',')
      self._num_files = 0
      for input_pattern in self._input_patterns:
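The guard at the top of the constructor makes `input_path` and `tfds_name` mutually exclusive. Here is a minimal, self-contained sketch of the same behavior; `SimpleConfig` is a hypothetical stand-in for the real params object, not part of the commit:

import dataclasses


@dataclasses.dataclass
class SimpleConfig:
  input_path: str = ''
  tfds_name: str = ''


def validate(params: SimpleConfig) -> None:
  # Mirrors the guard in InputReader.__init__: at most one input source.
  if params.input_path and params.tfds_name:
    raise ValueError('At most one of `input_path` and `tfds_name` can be '
                     'specified, but got %s and %s.' %
                     (params.input_path, params.tfds_name))


validate(SimpleConfig(input_path='train-*.tfrecord'))  # OK: file pattern only.
validate(SimpleConfig(tfds_name='mnist'))              # OK: TFDS source only.
# validate(SimpleConfig(input_path='a', tfds_name='mnist'))  # raises ValueError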
@@ -71,6 +76,13 @@ class InputReader:

        self._num_files += len(matched_files)
      if self._num_files == 0:
        raise ValueError('%s does not match any files.' % params.input_path)
    else:
      if not params.tfds_split:
        raise ValueError(
            '`tfds_name` is %s, but `tfds_split` is not specified.' %
            params.tfds_name)
      self._tfds_builder = tfds.builder(
          params.tfds_name, data_dir=params.tfds_data_dir)

    self._global_batch_size = params.global_batch_size
    self._is_training = params.is_training
@@ -78,8 +90,13 @@ class InputReader:

    self._shuffle_buffer_size = params.shuffle_buffer_size
    self._cache = params.cache
    self._cycle_length = params.cycle_length
    self._block_length = params.block_length
    self._sharding = params.sharding
    self._examples_consume = params.examples_consume
    self._tfds_split = params.tfds_split
    self._tfds_download = params.tfds_download
    self._tfds_as_supervised = params.tfds_as_supervised
    self._tfds_skip_decoding_feature = params.tfds_skip_decoding_feature

    self._dataset_fn = dataset_fn
    self._decoder_fn = decoder_fn
@@ -107,6 +124,7 @@ class InputReader:

      dataset = dataset.interleave(
          map_func=self._dataset_fn,
          cycle_length=self._cycle_length,
          block_length=self._block_length,
          num_parallel_calls=tf.data.experimental.AUTOTUNE)

    return dataset
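The interleave above fans out over the matched files, with tf.data choosing the parallelism via `AUTOTUNE`. A standalone sketch of the same pattern, using hypothetical shard names (the pipeline constructs as written; iterating it would of course require the files to exist):

import tensorflow as tf

# Hypothetical shard list; in InputReader these come from matching the
# `input_path` patterns.
files = tf.data.Dataset.from_tensor_slices(
    ['shard-00000.tfrecord', 'shard-00001.tfrecord'])

# Read several shards concurrently, taking `block_length` records from each
# before cycling to the next shard.
dataset = files.interleave(
    map_func=tf.data.TFRecordDataset,
    cycle_length=2,
    block_length=1,
    num_parallel_calls=tf.data.experimental.AUTOTUNE)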
@@ -131,12 +149,47 @@ class InputReader:

      dataset = dataset.repeat()
    return dataset

  def _read_tfds(
      self,
      input_context: Optional[tf.distribute.InputContext] = None
  ) -> tf.data.Dataset:
    """Reads a dataset from tfds."""
    if self._tfds_download:
      self._tfds_builder.download_and_prepare()

    read_config = tfds.ReadConfig(
        interleave_cycle_length=self._cycle_length,
        interleave_block_length=self._block_length,
        input_context=input_context)
    decoders = {}
    if self._tfds_skip_decoding_feature:
      for skip_feature in self._tfds_skip_decoding_feature.split(','):
        decoders[skip_feature.strip()] = tfds.decode.SkipDecoding()
    dataset = self._tfds_builder.as_dataset(
        split=self._tfds_split,
        shuffle_files=self._is_training,
        as_supervised=self._tfds_as_supervised,
        decoders=decoders,
        read_config=read_config)
    return dataset

  @property
  def tfds_info(self) -> tfds.core.DatasetInfo:
    """Returns TFDS dataset info, if available."""
    if self._tfds_builder:
      return self._tfds_builder.info
    else:
      raise ValueError('tfds_info is not available, because the dataset '
                       'is not loaded from tfds.')

  def read(self,
           input_context: Optional[tf.distribute.InputContext] = None
          ) -> tf.data.Dataset:
    """Generates a tf.data.Dataset object."""
    if self._tfds_builder:
      dataset = self._read_tfds(input_context)
    elif self._num_files > 1:
      dataset = self._read_sharded_files(input_context)
    else:
      assert self._num_files == 1
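The same TFDS read path can be exercised directly with the public TFDS APIs that `_read_tfds` builds on. A hedged sketch, assuming tensorflow_datasets is installed and using 'mnist' purely as an example dataset name:

import tensorflow_datasets as tfds

# Build the dataset; download_and_prepare() corresponds to the
# `tfds_download=True` branch above.
builder = tfds.builder('mnist')  # example dataset, not from the commit
builder.download_and_prepare()

# ReadConfig mirrors the reader's interleave settings.
read_config = tfds.ReadConfig(
    interleave_cycle_length=8,
    interleave_block_length=1)

# Skip image decoding, as `tfds_skip_decoding_feature='image'` would.
decoders = {'image': tfds.decode.SkipDecoding()}

dataset = builder.as_dataset(
    split='train',
    shuffle_files=True,
    as_supervised=False,
    decoders=decoders,
    read_config=read_config)
print(builder.info.features)  # what the new `tfds_info` property exposes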
official/modeling/hyperparams/config_definitions.py (view file @ 5a3af75c)
@@ -31,7 +31,12 @@ class DataConfig(base_config.Config):

  Attributes:
    input_path: The path to the input. It can be either (1) a file pattern, or
      (2) multiple file patterns separated by comma. It should not be specified
      when the following `tfds_name` is specified.
    tfds_name: The name of the TensorFlow dataset (TFDS). It should not be
      specified when the above `input_path` is specified.
    tfds_split: A str indicating which split of the data to load from TFDS. It
      is required when the above `tfds_name` is specified.
    global_batch_size: The global batch size across all replicas.
    is_training: Whether this data is used for training or not.
    drop_remainder: Whether the last batch should be dropped in the case it has
@@ -41,21 +46,40 @@ class DataConfig(base_config.Config):

      from disk on the second epoch. Requires significant memory overhead.
    cycle_length: The number of files that will be processed concurrently when
      interleaving files.
    block_length: The number of consecutive elements to produce from each input
      element before cycling to another input element when interleaving files.
    sharding: Whether sharding is used in the input pipeline.
    examples_consume: An `integer` specifying the number of examples it will
      produce. If positive, it only takes this number of examples and raises
      tf.errors.OutOfRangeError after that. Default is -1, meaning it will
      exhaust all the examples in the dataset.
    tfds_data_dir: A str specifying the directory to read/write TFDS data.
    tfds_download: A bool to indicate whether to download data using TFDS.
    tfds_as_supervised: A bool. When loading a dataset from TFDS, if True, the
      returned tf.data.Dataset will have a 2-tuple structure (input, label)
      according to builder.info.supervised_keys; if False (the default), the
      returned tf.data.Dataset will have a dictionary with all the features.
    tfds_skip_decoding_feature: A str to indicate which features are skipped
      for decoding when loading a dataset from TFDS. Use commas to separate
      multiple features. The main use case is to skip the image/video decoding
      for better performance.
  """
  input_path: str = ""
  tfds_name: str = ""
  tfds_split: str = ""
  global_batch_size: int = 0
  is_training: bool = None
  drop_remainder: bool = True
  shuffle_buffer_size: int = 100
  cache: bool = False
  cycle_length: int = 8
  block_length: int = 1
  sharding: bool = True
  examples_consume: int = -1
  tfds_data_dir: str = ""
  tfds_download: bool = False
  tfds_as_supervised: bool = False
  tfds_skip_decoding_feature: str = ""


@dataclasses.dataclass
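Since `DataConfig` is a dataclass-based config, the two input modes would be selected roughly as follows. This is a sketch only: it assumes the tensorflow/models repository is importable as `official`, and all field values are illustrative rather than taken from the commit:

from official.modeling.hyperparams import config_definitions as cfg

# (1) File-pattern input: `tfds_name` stays empty (hypothetical path).
file_params = cfg.DataConfig(
    input_path='/data/train-*.tfrecord',
    global_batch_size=64,
    is_training=True)

# (2) TFDS input: `input_path` stays empty and `tfds_split` is required.
tfds_params = cfg.DataConfig(
    tfds_name='mnist',
    tfds_split='train',
    tfds_data_dir='/data/tfds',
    global_batch_size=64,
    is_training=True)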