Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
5b60084d
Commit
5b60084d
authored
Jan 19, 2021
by
Ruoxin Sang
Committed by
A. Unique TensorFlower
Jan 19, 2021
Browse files
Internal change
PiperOrigin-RevId: 352730170
parent
e31d3f37
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
40 additions
and
7 deletions
+40
-7
official/core/input_reader.py
official/core/input_reader.py
+40
-7
No files found.
official/core/input_reader.py
View file @
5b60084d
...
...
@@ -30,6 +30,10 @@ def _get_random_integer():
class
InputReader
:
"""Input reader that returns a tf.data.Dataset instance."""
# A static random number which is the same across different InputReader
# instances.
static_randnum
=
_get_random_integer
()
def
__init__
(
self
,
params
:
cfg
.
DataConfig
,
dataset_fn
=
tf
.
data
.
TFRecordDataset
,
...
...
@@ -136,7 +140,13 @@ class InputReader:
self
.
_enable_tf_data_service
=
(
params
.
enable_tf_data_service
and
params
.
tf_data_service_address
)
self
.
_tf_data_service_address
=
params
.
tf_data_service_address
self
.
_tf_data_service_job_name
=
params
.
tf_data_service_job_name
if
self
.
_enable_tf_data_service
:
# Add a random seed as the tf.data service job name suffix, so tf.data
# service doesn't reuse the previous state if TPU worker gets preempted.
self
.
_tf_data_service_job_name
=
(
params
.
tf_data_service_job_name
+
str
(
self
.
static_randnum
))
self
.
_enable_round_robin_tf_data_service
=
params
.
get
(
'enable_round_robin_tf_data_service'
,
False
)
def
_shard_files_then_read
(
self
,
input_context
:
Optional
[
tf
.
distribute
.
InputContext
]
=
None
):
...
...
@@ -276,12 +286,35 @@ class InputReader:
dataset
=
maybe_map_fn
(
dataset
,
self
.
_postprocess_fn
)
if
self
.
_enable_tf_data_service
:
dataset
=
dataset
.
apply
(
tf
.
data
.
experimental
.
service
.
distribute
(
processing_mode
=
'parallel_epochs'
,
service
=
self
.
_tf_data_service_address
,
job_name
=
self
.
_tf_data_service_job_name
))
if
self
.
_enable_tf_data_service
and
input_context
:
if
self
.
_enable_round_robin_tf_data_service
:
replicas_per_input_pipeline
=
input_context
.
num_replicas_in_sync
//
(
input_context
.
num_input_pipelines
)
base_consumer_index
=
input_context
.
input_pipeline_id
*
(
replicas_per_input_pipeline
)
num_consumers
=
input_context
.
num_input_pipelines
*
(
replicas_per_input_pipeline
)
range_dataset
=
tf
.
data
.
Dataset
.
range
(
replicas_per_input_pipeline
)
dataset
=
range_dataset
.
map
(
lambda
i
:
dataset
.
apply
(
# pylint: disable=g-long-lambda
tf
.
data
.
experimental
.
service
.
distribute
(
processing_mode
=
'parallel_epochs'
,
service
=
self
.
_tf_data_service_address
,
job_name
=
self
.
_tf_data_service_job_name
,
consumer_index
=
base_consumer_index
+
i
,
num_consumers
=
num_consumers
)))
# Use parallel interleave to read multiple batches from a tf.data
# service worker in parallel.
dataset
=
dataset
.
interleave
(
lambda
x
:
x
,
cycle_length
=
replicas_per_input_pipeline
,
num_parallel_calls
=
replicas_per_input_pipeline
,
deterministic
=
True
)
else
:
dataset
=
dataset
.
apply
(
tf
.
data
.
experimental
.
service
.
distribute
(
processing_mode
=
'parallel_epochs'
,
service
=
self
.
_tf_data_service_address
,
job_name
=
self
.
_tf_data_service_job_name
))
if
self
.
_deterministic
is
not
None
:
options
=
tf
.
data
.
Options
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment