ModelZoo / ResNet50_tensorflow

Commit f15591f4
Authored Aug 27, 2020 by Ruoxin Sang; committed by A. Unique TensorFlower, Aug 27, 2020

Internal change

PiperOrigin-RevId: 328789639
Parent: 7f770b62

Showing 4 changed files with 57 additions and 4 deletions (+57 / -4):

  official/common/flags.py                              +3   -0
  official/core/input_reader.py                         +30  -3
  official/core/train_utils.py                          +9   -1
  official/modeling/hyperparams/config_definitions.py   +15  -0
official/common/flags.py

@@ -75,3 +75,6 @@ def define_flags():
       help='The Cloud TPU to use for training. This should be either the name '
       'used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 '
       'url.')
+
+  flags.DEFINE_string(
+      'tf_data_service', default=None, help='The tf.data service address')
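For context, a minimal absl sketch of how a flag defined this way is consumed at runtime; the main() and binary invocation are illustrative, and only 'tf_data_service' comes from this diff:

from absl import app, flags

flags.DEFINE_string(
    'tf_data_service', default=None, help='The tf.data service address')

FLAGS = flags.FLAGS

def main(_):
  # None (the default) leaves the service disabled; otherwise a dispatcher
  # address is expected, e.g. --tf_data_service=grpc://tf-data-dispatcher:5050
  print('tf.data service address:', FLAGS.tf_data_service)

if __name__ == '__main__':
  app.run(main)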
official/core/input_reader.py

@@ -100,6 +100,7 @@ class InputReader:
     self._cache = params.cache
     self._cycle_length = params.cycle_length
     self._block_length = params.block_length
+    self._deterministic = params.deterministic
     self._sharding = params.sharding
     self._examples_consume = params.examples_consume
     self._tfds_split = params.tfds_split

@@ -114,6 +115,11 @@ class InputReader:
     self._postprocess_fn = postprocess_fn
     self._seed = _get_random_integer()

+    self._enable_tf_data_service = (
+        params.enable_tf_data_service and params.tf_data_service_address)
+    self._tf_data_service_address = params.tf_data_service_address
+    self._tf_data_service_job_name = params.tf_data_service_job_name
+
   def _read_sharded_files(
       self,
       input_context: Optional[tf.distribute.InputContext] = None):

@@ -134,8 +140,11 @@ class InputReader:
           seed=self._seed,
           reshuffle_each_iteration=True)
+    # Do not enable sharding if tf.data service is enabled, as sharding will be
+    # handled inside tf.data service.
     if self._sharding and input_context and (
-        input_context.num_input_pipelines > 1):
+        input_context.num_input_pipelines > 1 and
+        not self._enable_tf_data_service):
       dataset = dataset.shard(input_context.num_input_pipelines,
                               input_context.input_pipeline_id)
     if self._is_training:

@@ -145,7 +154,8 @@ class InputReader:
         map_func=self._dataset_fn,
         cycle_length=self._cycle_length,
         block_length=self._block_length,
-        num_parallel_calls=tf.data.experimental.AUTOTUNE)
+        num_parallel_calls=tf.data.experimental.AUTOTUNE,
+        deterministic=self._deterministic)
     return dataset

   def _read_single_file(

@@ -161,8 +171,11 @@ class InputReader:
       options.experimental_distribute.auto_shard_policy = (
           tf.data.experimental.AutoShardPolicy.OFF)
       dataset = dataset.with_options(options)
+    # Do not enable sharding if tf.data service is enabled, as sharding will be
+    # handled inside tf.data service.
     if self._sharding and input_context and (
-        input_context.num_input_pipelines > 1):
+        input_context.num_input_pipelines > 1 and
+        not self._enable_tf_data_service):
       dataset = dataset.shard(input_context.num_input_pipelines,
                               input_context.input_pipeline_id)
     if self._is_training:

@@ -243,4 +256,18 @@ class InputReader:
           per_replica_batch_size, drop_remainder=self._drop_remainder)

     dataset = maybe_map_fn(dataset, self._postprocess_fn)
+
+    if self._enable_tf_data_service:
+      dataset = dataset.apply(
+          tf.data.experimental.service.distribute(
+              processing_mode='parallel_epochs',
+              service=self._tf_data_service_address,
+              job_name=self._tf_data_service_job_name))
+      dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
+
+    if self._deterministic is not None:
+      options = tf.data.Options()
+      options.experimental_deterministic = self._deterministic
+      dataset = dataset.with_options(options)
+
     return dataset.prefetch(tf.data.experimental.AUTOTUNE)
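As a standalone illustration of the API this file now wires in: tf.data.experimental.service.distribute is the real TF 2.3 symbol used in the last hunk, while the dispatcher address, job name, and toy pipeline below are made up. Under 'parallel_epochs' the service workers take over processing the dataset, which is why the hunks above also skip the client-side dataset.shard() call whenever the service is enabled:

import tensorflow as tf

# Toy input pipeline; in the real InputReader this is built from input files.
dataset = tf.data.Dataset.range(1024)
dataset = dataset.map(lambda x: x * 2,
                      num_parallel_calls=tf.data.experimental.AUTOTUNE)

# Offload the pipeline above to a tf.data service cluster. The address is
# hypothetical; sharing a job_name lets multiple consumers read from one job.
dataset = dataset.apply(
    tf.data.experimental.service.distribute(
        processing_mode='parallel_epochs',
        service='grpc://tf-data-dispatcher:5050',
        job_name='shared_training_job'))
dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)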
official/core/train_utils.py

@@ -60,10 +60,18 @@ def parse_configuration(flags_obj):
     params = hyperparams.override_params_dict(
         params, config_file, is_strict=True)

-  # 3. Override the TPU address.
+  # 3. Override the TPU address and tf.data service address.
   params.override({
       'runtime': {
           'tpu': flags_obj.tpu,
       },
+      'task': {
+          'train_data': {
+              'tf_data_service_address': flags_obj.tf_data_service,
+          },
+          'validation_data': {
+              'tf_data_service_address': flags_obj.tf_data_service,
+          }
+      }
   })
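A minimal sketch of the nested-dict override pattern this hunk relies on, with a plain-dict stand-in for the real hyperparams Config type (the helper and values are illustrative, not the library's implementation):

# Toy stand-in for params.override(): recursively merge an override dict
# into a nested config, touching only the keys that are listed.
def override(config, overrides):
  for key, value in overrides.items():
    if isinstance(value, dict):
      override(config.setdefault(key, {}), value)
    else:
      config[key] = value

params = {'runtime': {'tpu': None},
          'task': {'train_data': {'tf_data_service_address': None}}}
override(params, {'runtime': {'tpu': 'my-tpu'},
                  'task': {'train_data': {
                      'tf_data_service_address': 'grpc://svc:5050'}}})
assert params['runtime']['tpu'] == 'my-tpu'
assert params['task']['train_data']['tf_data_service_address'] == 'grpc://svc:5050'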
official/modeling/hyperparams/config_definitions.py

@@ -48,11 +48,22 @@ class DataConfig(base_config.Config):
       interleaving files.
     block_length: The number of consecutive elements to produce from each input
       element before cycling to another input element when interleaving files.
+    deterministic: A boolean controlling whether determinism should be enforced.
     sharding: Whether sharding is used in the input pipeline.
     examples_consume: An `integer` specifying the number of examples it will
       produce. If positive, it only takes this number of examples and raises
       tf.error.OutOfRangeError after that. Default is -1, meaning it will
       exhaust all the examples in the dataset.
+    enable_tf_data_service: A boolean indicating whether to enable tf.data
+      service for the input pipeline.
+    tf_data_service_address: The URI of a tf.data service to offload
+      preprocessing onto during training. The URI should be in the format
+      "protocol://address", e.g. "grpc://tf-data-service:5050". It can be
+      overridden by `FLAGS.tf_data_service` flag in the binary.
+    tf_data_service_job_name: The name of the tf.data service job. This
+      argument makes it possible for multiple datasets to share the same job.
+      The default behavior is that the dataset creates anonymous, exclusively
+      owned jobs.
     tfds_data_dir: A str specifying the directory to read/write TFDS data.
     tfds_download: A bool to indicate whether to download data using TFDS.
     tfds_as_supervised: A bool. When loading dataset from TFDS, if True, the

@@ -74,8 +85,12 @@ class DataConfig(base_config.Config):
   cache: bool = False
   cycle_length: int = 8
   block_length: int = 1
+  deterministic: Optional[bool] = None
   sharding: bool = True
   examples_consume: int = -1
+  enable_tf_data_service: bool = False
+  tf_data_service_address: Optional[str] = None
+  tf_data_service_job_name: Optional[str] = None
   tfds_data_dir: str = ""
   tfds_download: bool = False
   tfds_as_supervised: bool = False
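To see how the new fields compose, a hedged usage sketch: the field names match the dataclass above, while the import alias, input path, and values are illustrative. Note from the input_reader.py hunks that the service only activates when both enable_tf_data_service and tf_data_service_address are set:

from official.modeling.hyperparams import config_definitions as cfg

train_data = cfg.DataConfig(
    input_path='gs://my-bucket/imagenet/train*',  # illustrative path
    is_training=True,
    global_batch_size=1024,
    deterministic=None,                # keep tf.data's default determinism
    enable_tf_data_service=True,
    tf_data_service_address='grpc://tf-data-dispatcher:5050',
    tf_data_service_job_name='resnet_train')      # share one job across hosts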