Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
469339ec
Commit
469339ec
authored
Jun 15, 2021
by
Frederick Liu
Committed by
A. Unique TensorFlower
Jun 15, 2021
Browse files
Internal change
PiperOrigin-RevId: 379618127
parent
bcbce005
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
15 additions
and
10 deletions
+15
-10
official/core/input_reader.py
official/core/input_reader.py
+15
-10
No files found.
official/core/input_reader.py
View file @
469339ec
...
@@ -110,14 +110,15 @@ class InputReader:
...
@@ -110,14 +110,15 @@ class InputReader:
self
.
_parser_fn
=
parser_fn
self
.
_parser_fn
=
parser_fn
self
.
_transform_and_batch_fn
=
transform_and_batch_fn
self
.
_transform_and_batch_fn
=
transform_and_batch_fn
self
.
_postprocess_fn
=
postprocess_fn
self
.
_postprocess_fn
=
postprocess_fn
self
.
_seed
=
params
.
seed
# When tf.data service is enabled, each data service worker should get
# When tf.data service is enabled, each data service worker should get
# different random seeds. Thus, we set `seed` to None.
# different random seeds. Thus, we set `seed` to None.
if
params
.
seed
is
not
None
:
# Sharding should also be disabled because tf data service handles how
self
.
_seed
=
params
.
seed
# each worker shards data with `processing_mode` in the distribute method.
elif
params
.
enable_tf_data_service
:
if
params
.
enable_tf_data_service
:
self
.
_seed
=
_get_random_integer
()
else
:
self
.
_seed
=
None
self
.
_seed
=
None
self
.
_sharding
=
False
self
.
_enable_tf_data_service
=
(
self
.
_enable_tf_data_service
=
(
params
.
enable_tf_data_service
and
params
.
tf_data_service_address
)
params
.
enable_tf_data_service
and
params
.
tf_data_service_address
)
...
@@ -181,16 +182,21 @@ class InputReader:
...
@@ -181,16 +182,21 @@ class InputReader:
# If cache is enabled, `reshuffle_each_iteration` is set to False,
# If cache is enabled, `reshuffle_each_iteration` is set to False,
# because we will read the same cached data in every iteration anyway.
# because we will read the same cached data in every iteration anyway.
if
self
.
_is_training
:
if
self
.
_is_training
:
# We need a seed to shuffle the files so that when each TPU worker gets
# its own shard, the files do not overlap.
if
self
.
_sharding
and
self
.
_seed
is
None
:
seed
=
_get_random_integer
()
else
:
seed
=
self
.
_seed
dataset
=
dataset
.
shuffle
(
dataset
=
dataset
.
shuffle
(
len
(
matched_files
),
len
(
matched_files
),
seed
=
self
.
_
seed
,
seed
=
seed
,
reshuffle_each_iteration
=
True
if
not
self
.
_cache
else
False
)
reshuffle_each_iteration
=
True
if
not
self
.
_cache
else
False
)
# Do not enable sharding if tf.data service is enabled, as sharding will be
# Do not enable sharding if tf.data service is enabled, as sharding will be
# handled inside tf.data service.
# handled inside tf.data service.
if
self
.
_sharding
and
input_context
and
(
if
self
.
_sharding
and
input_context
and
(
input_context
.
num_input_pipelines
>
1
and
input_context
.
num_input_pipelines
>
1
):
not
self
.
_enable_tf_data_service
):
dataset
=
dataset
.
shard
(
input_context
.
num_input_pipelines
,
dataset
=
dataset
.
shard
(
input_context
.
num_input_pipelines
,
input_context
.
input_pipeline_id
)
input_context
.
input_pipeline_id
)
...
@@ -226,8 +232,7 @@ class InputReader:
...
@@ -226,8 +232,7 @@ class InputReader:
# Do not enable sharding if tf.data service is enabled, as sharding will be
# Do not enable sharding if tf.data service is enabled, as sharding will be
# handled inside tf.data service.
# handled inside tf.data service.
if
self
.
_sharding
and
input_context
and
(
if
self
.
_sharding
and
input_context
and
(
input_context
.
num_input_pipelines
>
1
and
input_context
.
num_input_pipelines
>
1
):
not
self
.
_enable_tf_data_service
):
dataset
=
dataset
.
shard
(
input_context
.
num_input_pipelines
,
dataset
=
dataset
.
shard
(
input_context
.
num_input_pipelines
,
input_context
.
input_pipeline_id
)
input_context
.
input_pipeline_id
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment