ModelZoo / ResNet50_tensorflow · Commits · 0f0b060c

Commit 0f0b060c, authored Sep 23, 2021 by Hongkun Yu, committed by A. Unique TensorFlower on Sep 23, 2021

    Refactor input reader

    PiperOrigin-RevId: 398593113

Parent: 201d523a

Changes: 2 changed files, with 244 additions and 183 deletions
    official/core/input_reader.py                       +235  -180
    official/vision/beta/dataloaders/input_reader.py      +9    -3
--- a/official/core/input_reader.py
+++ b/official/core/input_reader.py
@@ -14,7 +14,7 @@
 """A common dataset reader."""
 import random
-from typing import Any, Callable, List, Optional, Union, Dict, Sequence
+from typing import Any, Callable, Dict, List, Optional, Sequence, Text, Union
 from absl import logging
 import tensorflow as tf
@@ -34,6 +34,154 @@ def _maybe_map_fn(dataset: tf.data.Dataset,
       fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)


+def match_files(input_path: Union[Sequence[str], str]) -> List[str]:
+  """Matches files from an input_path."""
+  matched_files = []
+  # Read dataset from files.
+  usage = ('`input_path` should be either (1) a str indicating a file '
+           'path/pattern, or (2) a str indicating multiple file '
+           'paths/patterns separated by comma (e.g "a, b, c" or no spaces '
+           '"a,b,c", or (3) a list of str, each of which is a file '
+           'path/pattern or multiple file paths/patterns separated by '
+           'comma, but got: %s')
+  if isinstance(input_path, str):
+    input_path_list = [input_path]
+  elif isinstance(input_path, (list, tuple)):
+    if any(not isinstance(x, str) for x in input_path):
+      raise ValueError(usage % input_path)
+    input_path_list = input_path
+  else:
+    raise ValueError(usage % input_path)
+
+  for input_path in input_path_list:
+    input_patterns = input_path.strip().split(',')
+    for input_pattern in input_patterns:
+      input_pattern = input_pattern.strip()
+      if not input_pattern:
+        continue
+      if '*' in input_pattern or '?' in input_pattern:
+        tmp_matched_files = tf.io.gfile.glob(input_pattern)
+        if not tmp_matched_files:
+          raise ValueError('%s does not match any files.' % input_pattern)
+        matched_files.extend(tmp_matched_files)
+      else:
+        matched_files.append(input_pattern)
+
+  if not matched_files:
+    raise ValueError('%s does not match any files.' % input_path)
+
+  return matched_files
+
+
+def _read_files_then_shard(matched_files: List[str],
+                           dataset_fn,
+                           input_context: Optional[
+                               tf.distribute.InputContext] = None,
+                           sharding: bool = False,
+                           repeat: bool = False) -> tf.data.Dataset:
+  """Sends all data files to every worker and then shard by data."""
+  dataset = dataset_fn(matched_files)
+
+  # When `input_file` is a path to a single file or the number of files is
+  # less than the number of input pipelines, disable auto sharding
+  # so that same input file is sent to all workers.
+  options = tf.data.Options()
+  options.experimental_distribute.auto_shard_policy = (
+      tf.data.experimental.AutoShardPolicy.OFF)
+  dataset = dataset.with_options(options)
+  # Do not enable sharding if tf.data service is enabled, as sharding will be
+  # handled inside tf.data service.
+  if sharding and input_context and (input_context.num_input_pipelines > 1):
+    dataset = dataset.shard(input_context.num_input_pipelines,
+                            input_context.input_pipeline_id)
+
+  if repeat:
+    dataset = dataset.repeat()
+  return dataset
+
+
+def _shard_files_then_read(matched_files: List[str],
+                           dataset_fn,
+                           input_context: Optional[
+                               tf.distribute.InputContext] = None,
+                           seed: Optional[Union[int, tf.Tensor]] = None,
+                           is_training: bool = False,
+                           sharding: bool = False,
+                           cache: bool = False,
+                           cycle_length: Optional[int] = None,
+                           block_length: Optional[int] = None,
+                           deterministic: bool = False) -> tf.data.Dataset:
+  """Shards the data files and then sent a split to every worker to read."""
+  dataset = tf.data.Dataset.from_tensor_slices(matched_files)
+
+  # Shuffle and repeat at file level.
+  # If cache is enabled, `reshuffle_each_iteration` is set to False,
+  # because we will read the same cached data in every iteration anyway.
+  if is_training:
+    # We need a seed to shuffle the files so that when each TPU workers gets
+    # its own shard the files do not overlap.
+    if sharding and seed is None:
+      seed = _get_random_integer()
+    dataset = dataset.shuffle(
+        len(matched_files),
+        seed=seed,
+        reshuffle_each_iteration=True if not cache else False)
+
+  # Do not enable sharding if tf.data service is enabled, as sharding will be
+  # handled inside tf.data service.
+  if sharding and input_context and (input_context.num_input_pipelines > 1):
+    dataset = dataset.shard(input_context.num_input_pipelines,
+                            input_context.input_pipeline_id)
+
+  # If cache is enabled, we will call `repeat()` later after `cache()`.
+  if is_training and not cache:
+    dataset = dataset.repeat()
+
+  dataset = dataset.interleave(
+      map_func=dataset_fn,
+      cycle_length=cycle_length,
+      block_length=block_length,
+      num_parallel_calls=(cycle_length
+                          if cycle_length else tf.data.experimental.AUTOTUNE),
+      deterministic=deterministic)
+  return dataset
+
+
+def _read_tfds(tfds_builder: tfds.core.DatasetBuilder,
+               tfds_split: Text,
+               tfds_skip_decoding_feature: Text,
+               tfds_as_supervised: bool,
+               input_context: Optional[tf.distribute.InputContext] = None,
+               seed: Optional[Union[int, tf.Tensor]] = None,
+               is_training: bool = False,
+               cache: bool = False,
+               cycle_length: Optional[int] = None,
+               block_length: Optional[int] = None) -> tf.data.Dataset:
+  """Reads a dataset from tfds."""
+  # No op if exist.
+  tfds_builder.download_and_prepare()
+
+  read_config = tfds.ReadConfig(
+      interleave_cycle_length=cycle_length,
+      interleave_block_length=block_length,
+      input_context=input_context,
+      shuffle_seed=seed)
+  decoders = {}
+  if tfds_skip_decoding_feature:
+    for skip_feature in tfds_skip_decoding_feature.split(','):
+      decoders[skip_feature.strip()] = tfds.decode.SkipDecoding()
+  dataset = tfds_builder.as_dataset(
+      split=tfds_split,
+      shuffle_files=is_training,
+      as_supervised=tfds_as_supervised,
+      decoders=decoders,
+      read_config=read_config)
+
+  if is_training and not cache:
+    dataset = dataset.repeat()
+  return dataset
+
+
 class InputReader:
   """Input reader that returns a tf.data.Dataset instance."""
@@ -90,16 +238,7 @@ class InputReader:
     self._tfds_builder = None
     self._matched_files = None
-    if params.input_path:
-      # we want to combine / mix datasets
-      if isinstance(params.input_path, cfg.base_config.Config):
-        self._matched_files = {}
-        for k, v in params.input_path.as_dict().items():
-          self._matched_files[k] = self._match_files(v)
-      # single dataset
-      else:
-        self._matched_files = self._match_files(params.input_path)
-    else:
+    if not params.input_path:
       # Read dataset from TFDS.
       if not params.tfds_split:
         raise ValueError(
@@ -107,6 +246,8 @@ class InputReader:
             params.tfds_name)
       self._tfds_builder = tfds.builder(
           params.tfds_name, data_dir=params.tfds_data_dir)
+    else:
+      self._matched_files = self.get_files(params.input_path)

     self._global_batch_size = params.global_batch_size
     self._is_training = params.is_training
@@ -149,145 +290,6 @@ class InputReader:
     self._enable_round_robin_tf_data_service = params.get(
         'enable_round_robin_tf_data_service', False)

-  def _match_files(self, input_path: Union[Sequence[str], str]) -> List[str]:
-    """Matches files from an input_path."""
-    matched_files = []
-    # Read dataset from files.
-    usage = ('`input_path` should be either (1) a str indicating a file '
-             'path/pattern, or (2) a str indicating multiple file '
-             'paths/patterns separated by comma (e.g "a, b, c" or no spaces '
-             '"a,b,c", or (3) a list of str, each of which is a file '
-             'path/pattern or multiple file paths/patterns separated by '
-             'comma, but got: %s')
-    if isinstance(input_path, str):
-      input_path_list = [input_path]
-    elif isinstance(input_path, (list, tuple)):
-      if any(not isinstance(x, str) for x in input_path):
-        raise ValueError(usage % input_path)
-      input_path_list = input_path
-    else:
-      raise ValueError(usage % input_path)
-
-    for input_path in input_path_list:
-      input_patterns = input_path.strip().split(',')
-      for input_pattern in input_patterns:
-        input_pattern = input_pattern.strip()
-        if not input_pattern:
-          continue
-        if '*' in input_pattern or '?' in input_pattern:
-          tmp_matched_files = tf.io.gfile.glob(input_pattern)
-          if not tmp_matched_files:
-            raise ValueError('%s does not match any files.' % input_pattern)
-          matched_files.extend(tmp_matched_files)
-        else:
-          matched_files.append(input_pattern)
-
-    if not matched_files:
-      raise ValueError('%s does not match any files.' % input_path)
-
-    return matched_files
-
-  def _shard_files_then_read(
-      self,
-      matched_files: List[str],
-      dataset_fn,
-      input_context: Optional[tf.distribute.InputContext] = None
-  ) -> tf.data.Dataset:
-    """Shards the data files and then sent a split to every worker to read."""
-    dataset = tf.data.Dataset.from_tensor_slices(matched_files)
-
-    # Shuffle and repeat at file level.
-    # If cache is enabled, `reshuffle_each_iteration` is set to False,
-    # because we will read the same cached data in every iteration anyway.
-    if self._is_training:
-      # We need a seed to shuffle the files so that when each TPU workers gets
-      # its own shard the files do not overlap.
-      if self._sharding and self._seed is None:
-        seed = _get_random_integer()
-      else:
-        seed = self._seed
-      dataset = dataset.shuffle(
-          len(matched_files),
-          seed=seed,
-          reshuffle_each_iteration=True if not self._cache else False)
-
-    # Do not enable sharding if tf.data service is enabled, as sharding will be
-    # handled inside tf.data service.
-    if self._sharding and input_context and (
-        input_context.num_input_pipelines > 1):
-      dataset = dataset.shard(input_context.num_input_pipelines,
-                              input_context.input_pipeline_id)
-
-    # If cache is enabled, we will call `repeat()` later after `cache()`.
-    if self._is_training and not self._cache:
-      dataset = dataset.repeat()
-
-    dataset = dataset.interleave(
-        map_func=dataset_fn,
-        cycle_length=self._cycle_length,
-        block_length=self._block_length,
-        num_parallel_calls=(self._cycle_length if self._cycle_length else
-                            tf.data.experimental.AUTOTUNE),
-        deterministic=self._deterministic)
-    return dataset
-
-  def _read_files_then_shard(
-      self,
-      matched_files: List[str],
-      dataset_fn,
-      input_context: Optional[tf.distribute.InputContext] = None
-  ) -> tf.data.Dataset:
-    """Sends all data files to every worker and then shard by data."""
-    dataset = dataset_fn(matched_files)
-
-    # When `input_file` is a path to a single file or the number of files is
-    # less than the number of input pipelines, disable auto sharding
-    # so that same input file is sent to all workers.
-    options = tf.data.Options()
-    options.experimental_distribute.auto_shard_policy = (
-        tf.data.experimental.AutoShardPolicy.OFF)
-    dataset = dataset.with_options(options)
-    # Do not enable sharding if tf.data service is enabled, as sharding will be
-    # handled inside tf.data service.
-    if self._sharding and input_context and (
-        input_context.num_input_pipelines > 1):
-      dataset = dataset.shard(input_context.num_input_pipelines,
-                              input_context.input_pipeline_id)
-
-    # If cache is enabled, we will call `repeat()` later after `cache()`.
-    if self._is_training and not self._cache:
-      dataset = dataset.repeat()
-    return dataset
-
-  def _read_tfds(
-      self,
-      input_context: Optional[tf.distribute.InputContext] = None
-  ) -> tf.data.Dataset:
-    """Reads a dataset from tfds."""
-    # No op if exist.
-    self._tfds_builder.download_and_prepare()
-
-    read_config = tfds.ReadConfig(
-        interleave_cycle_length=self._cycle_length,
-        interleave_block_length=self._block_length,
-        input_context=input_context,
-        shuffle_seed=self._seed)
-    decoders = {}
-    if self._tfds_skip_decoding_feature:
-      for skip_feature in self._tfds_skip_decoding_feature.split(','):
-        decoders[skip_feature.strip()] = tfds.decode.SkipDecoding()
-    dataset = self._tfds_builder.as_dataset(
-        split=self._tfds_split,
-        shuffle_files=self._is_training,
-        as_supervised=self._tfds_as_supervised,
-        decoders=decoders,
-        read_config=read_config)
-
-    # If cache is enabled, we will call `repeat()` later after `cache()`.
-    if self._is_training and not self._cache:
-      dataset = dataset.repeat()
-    return dataset
-
   @property
   def tfds_info(self) -> tfds.core.DatasetInfo:
     """Returns TFDS dataset info, if available."""
@@ -297,14 +299,27 @@ class InputReader:
       raise ValueError('tfds_info is not available, because the dataset '
                        'is not loaded from tfds.')

-  def _read_decode_and_parse_dataset(
+  def get_files(self, input_path):
+    """Gets matched files. Can be overridden by subclasses."""
+    if not input_path:
+      return None
+    # we want to combine / mix datasets
+    if isinstance(input_path, cfg.base_config.Config):
+      matched_files = {}
+      for k, v in input_path.as_dict().items():
+        matched_files[k] = match_files(v)
+    # single dataset
+    else:
+      matched_files = match_files(input_path)
+    return matched_files
+
+  def _read_data_source(
       self,
       matched_files: Union[Dict[str, List[str]], List[str]],
       dataset_fn,
-      batch_size: int,
       input_context: Optional[tf.distribute.InputContext] = None,
-      tfds_builder: bool = False) -> tf.data.Dataset:
-    """Returns a tf.data.Dataset object after reading, decoding, and parsing."""
+      tfds_builder: Optional[tfds.core.DatasetBuilder] = None):
+    """Reads the data source (files/tfds) to a dataset."""

     def _files_to_dataset(files: List[str]) -> tf.data.Dataset:
       if len(files) > 1:
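
Because get_files is explicitly documented as overridable, a subclass can customize file discovery while reusing the module-level match_files. A minimal hypothetical sketch (the subclass name and the '.tmp' filter are illustrative, not part of the commit):

# Hypothetical subclass; assumes the refactored official.core.input_reader API.
from official.core import input_reader


class FilteringInputReader(input_reader.InputReader):
  """Drops any matched file ending in '.tmp' before reading."""

  def get_files(self, input_path):
    matched = super().get_files(input_path)
    if matched is None:
      return None
    if isinstance(matched, dict):  # combined/mixed datasets
      return {k: [f for f in v if not f.endswith('.tmp')]
              for k, v in matched.items()}
    return [f for f in matched if not f.endswith('.tmp')]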
@@ -314,15 +329,66 @@ class InputReader:
                         '%d. We will send all input files to every worker. '
                         'Please consider sharding your data into more files.',
                         len(files), input_context.num_input_pipelines)
-          return self._read_files_then_shard(files, dataset_fn, input_context)
+          return _read_files_then_shard(
+              files,
+              dataset_fn,
+              input_context,
+              sharding=self._sharding,
+              repeat=self._is_training and not self._cache)
         else:
-          return self._shard_files_then_read(files, dataset_fn, input_context)
+          return _shard_files_then_read(
+              files,
+              dataset_fn,
+              input_context,
+              seed=self._seed,
+              is_training=self._is_training,
+              sharding=self._sharding,
+              cache=self._cache,
+              cycle_length=self._cycle_length,
+              block_length=self._block_length,
+              deterministic=self._deterministic)
       elif len(files) == 1:
-        return self._read_files_then_shard(files, dataset_fn, input_context)
+        return _read_files_then_shard(
+            files,
+            dataset_fn,
+            input_context,
+            sharding=self._sharding,
+            repeat=self._is_training and not self._cache)
       else:
         raise ValueError('It is unexpected that `tfds_builder` is None and '
                          'there is also no `files`.')

+    if tfds_builder:
+      dataset = _read_tfds(
+          tfds_builder=self._tfds_builder,
+          tfds_split=self._tfds_split,
+          tfds_skip_decoding_feature=self._tfds_skip_decoding_feature,
+          tfds_as_supervised=self._tfds_as_supervised,
+          input_context=input_context,
+          seed=self._seed,
+          is_training=self._is_training,
+          cache=self._cache,
+          cycle_length=self._cycle_length,
+          block_length=self._block_length)
+    elif isinstance(matched_files, (list, tuple)):
+      dataset = _files_to_dataset(matched_files)
+    elif isinstance(matched_files, dict):
+      dataset = {}
+      for k, fs in matched_files.items():
+        dataset[k] = _files_to_dataset(fs)
+    else:
+      raise ValueError('`matched_files` should be a list or dict.')
+    return dataset
+
+  def _decode_and_parse_dataset(
+      self,
+      dataset: Union[tf.data.Dataset, Dict[Text, tf.data.Dataset]],
+      batch_size: int,
+      input_context: Optional[tf.distribute.InputContext] = None
+  ) -> tf.data.Dataset:
+    """Returns a tf.data.Dataset object after shuffling, decoding, and parsing."""
+
     def _shuffle_and_decode(ds):
       # If cache is enabled, we will call `shuffle()` later after `cache()`.
       if self._is_training and not self._cache:
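
The branch above picks _shard_files_then_read only when there are at least as many files as input pipelines; otherwise every worker reads all files and sharding happens at the record level. A standalone sketch of the file-level case (not from the commit; the file names are dummies):

# Illustrative only: how file-level sharding splits work across input pipelines.
import tensorflow as tf

files = ['train-00000', 'train-00001', 'train-00002', 'train-00003']
ctx = tf.distribute.InputContext(num_input_pipelines=2, input_pipeline_id=0)

file_ds = tf.data.Dataset.from_tensor_slices(files)
# Each pipeline keeps every num_input_pipelines-th file, so shards do not overlap.
file_ds = file_ds.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
print(list(file_ds.as_numpy_iterator()))  # pipeline 0 -> [b'train-00000', b'train-00002']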
@@ -331,20 +397,9 @@ class InputReader:
       ds = _maybe_map_fn(ds, self._decoder_fn)
       return ds

-    if tfds_builder:
-      dataset = self._read_tfds(input_context)
-      dataset = _shuffle_and_decode(dataset)
-    elif isinstance(matched_files, (list, tuple)):
-      dataset = _files_to_dataset(matched_files)
-      dataset = _shuffle_and_decode(dataset)
-    elif isinstance(matched_files, dict):
-      datasets = {}
-      for k, fs in matched_files.items():
-        datasets[k] = _files_to_dataset(fs)
-        datasets[k] = _shuffle_and_decode(datasets[k])
-      dataset = self._combine_fn(datasets)
-    else:
-      raise ValueError('`matched_files` should be a list or dict.')
+    dataset = tf.nest.map_structure(_shuffle_and_decode, dataset)
+    if tf.nest.is_nested(dataset):
+      dataset = self._combine_fn(dataset)

     if self._sample_fn is not None:
       dataset = dataset.apply(self._sample_fn)
@@ -403,16 +458,16 @@ class InputReader:
               job_name=self._tf_data_service_job_name))
     return dataset

   def read(
       self,
-      input_context: Optional[tf.distribute.InputContext] = None
+      input_context: Optional[tf.distribute.InputContext] = None,
+      dataset: Optional[tf.data.Dataset] = None
   ) -> tf.data.Dataset:
     """Generates a tf.data.Dataset object."""
-    dataset = self._read_decode_and_parse_dataset(self._matched_files,
-                                                  self._dataset_fn,
-                                                  self._global_batch_size,
-                                                  input_context,
-                                                  self._tfds_builder)
+    if dataset is None:
+      dataset = self._read_data_source(self._matched_files, self._dataset_fn,
+                                       input_context, self._tfds_builder)
+    dataset = self._decode_and_parse_dataset(dataset, self._global_batch_size,
+                                             input_context)
     dataset = _maybe_map_fn(dataset, self._postprocess_fn)
     dataset = self._maybe_apply_data_service(dataset, input_context)
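
With the new dataset argument, read() can be handed an externally built source dataset and will only apply the decode/parse/batch and tf.data service stages to it. A hypothetical sketch (my_params, my_decoder_fn and the file pattern are placeholders; the constructor keywords dataset_fn/decoder_fn are assumed from the surrounding module):

# Hypothetical sketch of the new read(dataset=...) hook; my_params,
# my_decoder_fn and the file pattern are placeholders, not from the commit.
import tensorflow as tf
from official.core import input_reader

reader = input_reader.InputReader(
    params=my_params,                    # DataConfig-like object (placeholder)
    dataset_fn=tf.data.TFRecordDataset,  # assumed constructor keyword
    decoder_fn=my_decoder_fn)            # assumed constructor keyword

# Default path: the reader matches files / TFDS and reads them itself.
ds = reader.read(input_context=None)

# New path: pass a pre-built source dataset; _read_data_source is skipped and
# only _decode_and_parse_dataset and the later stages run.
raw = tf.data.TFRecordDataset(input_reader.match_files('/data/train-*'))
ds = reader.read(dataset=raw)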
--- a/official/vision/beta/dataloaders/input_reader.py
+++ b/official/vision/beta/dataloaders/input_reader.py
@@ -113,7 +113,7 @@ class CombinationDatasetInputReader(input_reader.InputReader):
     self._pseudo_label_file_pattern = params.pseudo_label_data.input_path
     self._pseudo_label_dataset_fn = pseudo_label_dataset_fn
     self._pseudo_label_data_ratio = params.pseudo_label_data.data_ratio
-    self._pseudo_label_matched_files = self._match_files(
+    self._pseudo_label_matched_files = input_reader.match_files(
         self._pseudo_label_file_pattern)
     if not self._drop_remainder:
       raise ValueError(
@@ -134,14 +134,20 @@ class CombinationDatasetInputReader(input_reader.InputReader):
           'resulting in a 0 batch size for one of the datasets.'.format(
               self._global_batch_size, self._pseudo_label_data_ratio))

-    labeled_dataset = self._read_decode_and_parse_dataset(
+    def _read_decode_and_parse_dataset(matched_files, dataset_fn, batch_size,
+                                       input_context, tfds_builder):
+      dataset = self._read_data_source(matched_files, dataset_fn,
+                                       input_context, tfds_builder)
+      return self._decode_and_parse_dataset(dataset, batch_size, input_context)
+
+    labeled_dataset = _read_decode_and_parse_dataset(
         matched_files=self._matched_files,
         dataset_fn=self._dataset_fn,
         batch_size=labeled_batch_size,
         input_context=input_context,
         tfds_builder=self._tfds_builder)
-    pseudo_labeled_dataset = self._read_decode_and_parse_dataset(
+    pseudo_labeled_dataset = _read_decode_and_parse_dataset(
         matched_files=self._pseudo_label_matched_files,
         dataset_fn=self._pseudo_label_dataset_fn,
         batch_size=pl_batch_size,