Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
5a3af75c
Commit
5a3af75c
authored
Jun 10, 2020
by
Chen Chen
Committed by
A. Unique TensorFlower
Jun 10, 2020
Browse files
Support to read a dataset from TFDS.
PiperOrigin-RevId: 315774221
parent
8a1dbbad
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
81 additions
and
4 deletions
+81
-4
official/core/input_reader.py
official/core/input_reader.py
+56
-3
official/modeling/hyperparams/config_definitions.py
official/modeling/hyperparams/config_definitions.py
+25
-1
No files found.
official/core/input_reader.py
View file @
5a3af75c
...
@@ -18,6 +18,7 @@
...
@@ -18,6 +18,7 @@
from
typing
import
Any
,
Callable
,
List
,
Optional
from
typing
import
Any
,
Callable
,
List
,
Optional
import
tensorflow
as
tf
import
tensorflow
as
tf
import
tensorflow_datasets
as
tfds
from
official.modeling.hyperparams
import
config_definitions
as
cfg
from
official.modeling.hyperparams
import
config_definitions
as
cfg
...
@@ -53,11 +54,15 @@ class InputReader:
...
@@ -53,11 +54,15 @@ class InputReader:
postprocess_fn: An optional `callable` that processes batched tensors. It
postprocess_fn: An optional `callable` that processes batched tensors. It
will be executed after batching.
will be executed after batching.
"""
"""
# TODO(chendouble): Support TFDS as input_path.
if
params
.
input_path
and
params
.
tfds_name
:
raise
ValueError
(
'At most one of `input_path` and `tfds_name` can be '
'specified, but got %s and %s.'
%
(
params
.
input_path
,
params
.
tfds_name
))
self
.
_shards
=
shards
self
.
_shards
=
shards
self
.
_tfds_builder
=
None
if
self
.
_shards
:
if
self
.
_shards
:
self
.
_num_files
=
len
(
self
.
_shards
)
self
.
_num_files
=
len
(
self
.
_shards
)
el
s
e
:
el
if
not
params
.
tfds_nam
e
:
self
.
_input_patterns
=
params
.
input_path
.
strip
().
split
(
','
)
self
.
_input_patterns
=
params
.
input_path
.
strip
().
split
(
','
)
self
.
_num_files
=
0
self
.
_num_files
=
0
for
input_pattern
in
self
.
_input_patterns
:
for
input_pattern
in
self
.
_input_patterns
:
...
@@ -71,6 +76,13 @@ class InputReader:
...
@@ -71,6 +76,13 @@ class InputReader:
self
.
_num_files
+=
len
(
matched_files
)
self
.
_num_files
+=
len
(
matched_files
)
if
self
.
_num_files
==
0
:
if
self
.
_num_files
==
0
:
raise
ValueError
(
'%s does not match any files.'
%
params
.
input_path
)
raise
ValueError
(
'%s does not match any files.'
%
params
.
input_path
)
else
:
if
not
params
.
tfds_split
:
raise
ValueError
(
'`tfds_name` is %s, but `tfds_split` is not specified.'
%
params
.
tfds_name
)
self
.
_tfds_builder
=
tfds
.
builder
(
params
.
tfds_name
,
data_dir
=
params
.
tfds_data_dir
)
self
.
_global_batch_size
=
params
.
global_batch_size
self
.
_global_batch_size
=
params
.
global_batch_size
self
.
_is_training
=
params
.
is_training
self
.
_is_training
=
params
.
is_training
...
@@ -78,8 +90,13 @@ class InputReader:
...
@@ -78,8 +90,13 @@ class InputReader:
self
.
_shuffle_buffer_size
=
params
.
shuffle_buffer_size
self
.
_shuffle_buffer_size
=
params
.
shuffle_buffer_size
self
.
_cache
=
params
.
cache
self
.
_cache
=
params
.
cache
self
.
_cycle_length
=
params
.
cycle_length
self
.
_cycle_length
=
params
.
cycle_length
self
.
_block_length
=
params
.
block_length
self
.
_sharding
=
params
.
sharding
self
.
_sharding
=
params
.
sharding
self
.
_examples_consume
=
params
.
examples_consume
self
.
_examples_consume
=
params
.
examples_consume
self
.
_tfds_split
=
params
.
tfds_split
self
.
_tfds_download
=
params
.
tfds_download
self
.
_tfds_as_supervised
=
params
.
tfds_as_supervised
self
.
_tfds_skip_decoding_feature
=
params
.
tfds_skip_decoding_feature
self
.
_dataset_fn
=
dataset_fn
self
.
_dataset_fn
=
dataset_fn
self
.
_decoder_fn
=
decoder_fn
self
.
_decoder_fn
=
decoder_fn
...
@@ -107,6 +124,7 @@ class InputReader:
...
@@ -107,6 +124,7 @@ class InputReader:
dataset
=
dataset
.
interleave
(
dataset
=
dataset
.
interleave
(
map_func
=
self
.
_dataset_fn
,
map_func
=
self
.
_dataset_fn
,
cycle_length
=
self
.
_cycle_length
,
cycle_length
=
self
.
_cycle_length
,
block_length
=
self
.
_block_length
,
num_parallel_calls
=
tf
.
data
.
experimental
.
AUTOTUNE
)
num_parallel_calls
=
tf
.
data
.
experimental
.
AUTOTUNE
)
return
dataset
return
dataset
...
@@ -131,12 +149,47 @@ class InputReader:
...
@@ -131,12 +149,47 @@ class InputReader:
dataset
=
dataset
.
repeat
()
dataset
=
dataset
.
repeat
()
return
dataset
return
dataset
def _read_tfds(
    self,
    input_context: Optional[tf.distribute.InputContext] = None
) -> tf.data.Dataset:
  """Reads a dataset from TFDS.

  Args:
    input_context: An optional `tf.distribute.InputContext`, forwarded to
      `tfds.ReadConfig` so TFDS can shard the files across input pipelines
      in a distributed setting.

  Returns:
    A `tf.data.Dataset` built by the TFDS dataset builder.
  """
  # Optionally download and prepare the data first; controlled by the
  # `tfds_download` config flag.
  if self._tfds_download:
    self._tfds_builder.download_and_prepare()

  # Mirror the file-based reader's interleave settings in the TFDS read.
  read_config = tfds.ReadConfig(
      interleave_cycle_length=self._cycle_length,
      interleave_block_length=self._block_length,
      input_context=input_context)

  # Build per-feature decoder overrides: each name in the comma-separated
  # `tfds_skip_decoding_feature` string is mapped to SkipDecoding(), e.g.
  # to avoid eager image decoding when the model decodes images itself.
  decoders = {}
  if self._tfds_skip_decoding_feature:
    for skip_feature in self._tfds_skip_decoding_feature.split(','):
      decoders[skip_feature.strip()] = tfds.decode.SkipDecoding()

  dataset = self._tfds_builder.as_dataset(
      split=self._tfds_split,
      # Shuffle input file order only when training.
      shuffle_files=self._is_training,
      as_supervised=self._tfds_as_supervised,
      decoders=decoders,
      read_config=read_config)
  return dataset
@property
def tfds_info(self) -> tfds.core.DatasetInfo:
  """Returns TFDS dataset info, if available.

  Raises:
    ValueError: If this reader was not configured with a `tfds_name`,
      i.e. no TFDS builder was created, so no dataset info exists.
  """
  if self._tfds_builder:
    return self._tfds_builder.info
  else:
    raise ValueError('tfds_info is not available, because the dataset '
                     'is not loaded from tfds.')
def
read
(
def
read
(
self
,
self
,
input_context
:
Optional
[
tf
.
distribute
.
InputContext
]
=
None
input_context
:
Optional
[
tf
.
distribute
.
InputContext
]
=
None
)
->
tf
.
data
.
Dataset
:
)
->
tf
.
data
.
Dataset
:
"""Generates a tf.data.Dataset object."""
"""Generates a tf.data.Dataset object."""
if
self
.
_num_files
>
1
:
if
self
.
_tfds_builder
:
dataset
=
self
.
_read_tfds
(
input_context
)
elif
self
.
_num_files
>
1
:
dataset
=
self
.
_read_sharded_files
(
input_context
)
dataset
=
self
.
_read_sharded_files
(
input_context
)
else
:
else
:
assert
self
.
_num_files
==
1
assert
self
.
_num_files
==
1
...
...
official/modeling/hyperparams/config_definitions.py
View file @
5a3af75c
...
@@ -31,7 +31,12 @@ class DataConfig(base_config.Config):
...
@@ -31,7 +31,12 @@ class DataConfig(base_config.Config):
Attributes:
Attributes:
input_path: The path to the input. It can be either (1) a file pattern, or
input_path: The path to the input. It can be either (1) a file pattern, or
(2) multiple file patterns separated by comma.
(2) multiple file patterns separated by comma. It should not be specified
when the following `tfds_name` is specified.
tfds_name: The name of the tensorflow dataset (TFDS). It should not be
specified when the above `input_path` is specified.
tfds_split: A str indicating which split of the data to load from TFDS. It
is required when above `tfds_name` is specified.
global_batch_size: The global batch size across all replicas.
global_batch_size: The global batch size across all replicas.
is_training: Whether this data is used for training or not.
is_training: Whether this data is used for training or not.
drop_remainder: Whether the last batch should be dropped in the case it has
drop_remainder: Whether the last batch should be dropped in the case it has
...
@@ -41,21 +46,40 @@ class DataConfig(base_config.Config):
...
@@ -41,21 +46,40 @@ class DataConfig(base_config.Config):
from disk on the second epoch. Requires significant memory overhead.
from disk on the second epoch. Requires significant memory overhead.
cycle_length: The number of files that will be processed concurrently when
cycle_length: The number of files that will be processed concurrently when
interleaving files.
interleaving files.
block_length: The number of consecutive elements to produce from each input
element before cycling to another input element when interleaving files.
sharding: Whether sharding is used in the input pipeline.
sharding: Whether sharding is used in the input pipeline.
examples_consume: An `integer` specifying the number of examples it will
examples_consume: An `integer` specifying the number of examples it will
produce. If positive, it only takes this number of examples and raises
produce. If positive, it only takes this number of examples and raises
tf.errors.OutOfRangeError after that. Default is -1, meaning it will
tf.errors.OutOfRangeError after that. Default is -1, meaning it will
exhaust all the examples in the dataset.
exhaust all the examples in the dataset.
tfds_data_dir: A str specifying the directory to read/write TFDS data.
tfds_download: A bool to indicate whether to download data using TFDS.
tfds_as_supervised: A bool. When loading dataset from TFDS, if True,
the returned tf.data.Dataset will have a 2-tuple structure (input, label)
according to builder.info.supervised_keys; if False, the default,
the returned tf.data.Dataset will have a dictionary with all the features.
tfds_skip_decoding_feature: A str to indicate which features are skipped
for decoding when loading dataset from TFDS. Use comma to separate
multiple features. The main use case is to skip the image/video decoding
for better performance.
"""
"""
input_path
:
str
=
""
input_path
:
str
=
""
tfds_name
:
str
=
""
tfds_split
:
str
=
""
global_batch_size
:
int
=
0
global_batch_size
:
int
=
0
is_training
:
bool
=
None
is_training
:
bool
=
None
drop_remainder
:
bool
=
True
drop_remainder
:
bool
=
True
shuffle_buffer_size
:
int
=
100
shuffle_buffer_size
:
int
=
100
cache
:
bool
=
False
cache
:
bool
=
False
cycle_length
:
int
=
8
cycle_length
:
int
=
8
block_length
:
int
=
1
sharding
:
bool
=
True
sharding
:
bool
=
True
examples_consume
:
int
=
-
1
examples_consume
:
int
=
-
1
tfds_data_dir
:
str
=
""
tfds_download
:
bool
=
False
tfds_as_supervised
:
bool
=
False
tfds_skip_decoding_feature
:
str
=
""
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment