Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
ac0a29f6
Commit
ac0a29f6
authored
Mar 14, 2022
by
Hao Wu
Committed by
A. Unique TensorFlower
Mar 14, 2022
Browse files
Internal change
PiperOrigin-RevId: 434505499
parent
9adaa571
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
4 deletions
+12
-4
official/core/input_reader.py
official/core/input_reader.py
+12
-4
No files found.
official/core/input_reader.py
View file @
ac0a29f6
...
@@ -285,8 +285,17 @@ class InputReader:
...
@@ -285,8 +285,17 @@ class InputReader:
if
self
.
_enable_tf_data_service
:
if
self
.
_enable_tf_data_service
:
# Add a random seed as the tf.data service job name suffix, so tf.data
# Add a random seed as the tf.data service job name suffix, so tf.data
# service doesn't reuse the previous state if TPU worker gets preempted.
# service doesn't reuse the previous state if TPU worker gets preempted.
# It's necessary to add global batch size into the tf data service job
# name because when tuning batch size with vizier and tf data service is
# also enable, the tf data servce job name should be different for
# different vizier trials since once batch size is changed, from the
# tf.data perspective, the dataset is a different instance, and a
# different job name should be used for tf data service. Otherwise, the
# model would read tensors from the incorrect tf data service job, which
# would causes dimension mismatch on the batch size dimension.
self
.
_tf_data_service_job_name
=
(
self
.
_tf_data_service_job_name
=
(
params
.
tf_data_service_job_name
+
str
(
self
.
static_randnum
))
f
'
{
params
.
tf_data_service_job_name
}
_bs
{
params
.
global_batch_size
}
_'
f
'
{
self
.
static_randnum
}
'
)
self
.
_enable_round_robin_tf_data_service
=
params
.
get
(
self
.
_enable_round_robin_tf_data_service
=
params
.
get
(
'enable_round_robin_tf_data_service'
,
False
)
'enable_round_robin_tf_data_service'
,
False
)
...
@@ -463,9 +472,8 @@ class InputReader:
...
@@ -463,9 +472,8 @@ class InputReader:
dataset
:
Optional
[
tf
.
data
.
Dataset
]
=
None
)
->
tf
.
data
.
Dataset
:
dataset
:
Optional
[
tf
.
data
.
Dataset
]
=
None
)
->
tf
.
data
.
Dataset
:
"""Generates a tf.data.Dataset object."""
"""Generates a tf.data.Dataset object."""
if
dataset
is
None
:
if
dataset
is
None
:
dataset
=
self
.
_read_data_source
(
dataset
=
self
.
_read_data_source
(
self
.
_matched_files
,
self
.
_dataset_fn
,
self
.
_matched_files
,
self
.
_dataset_fn
,
input_context
,
input_context
,
self
.
_tfds_builder
)
self
.
_tfds_builder
)
dataset
=
self
.
_decode_and_parse_dataset
(
dataset
,
self
.
_global_batch_size
,
dataset
=
self
.
_decode_and_parse_dataset
(
dataset
,
self
.
_global_batch_size
,
input_context
)
input_context
)
dataset
=
_maybe_map_fn
(
dataset
,
self
.
_postprocess_fn
)
dataset
=
_maybe_map_fn
(
dataset
,
self
.
_postprocess_fn
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment