ModelZoo / ResNet50_tensorflow / Commits

Commit 20e2cb97
Authored Aug 27, 2020 by Ruoxin Sang; committed by A. Unique TensorFlower on Aug 27, 2020
Parent: 454e12b8

Internal change

PiperOrigin-RevId: 328789639
Showing 4 changed files with 57 additions and 4 deletions (+57 -4)
  official/common/flags.py                             +3   -0
  official/core/input_reader.py                        +30  -3
  official/core/train_utils.py                         +9   -1
  official/modeling/hyperparams/config_definitions.py  +15  -0
official/common/flags.py

@@ -75,3 +75,6 @@ def define_flags():
       help='The Cloud TPU to use for training. This should be either the name '
       'used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 '
       'url.')
+
+  flags.DEFINE_string(
+      'tf_data_service', default=None, help='The tf.data service address')
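
The new flag follows the standard absl.flags pattern already used in this module. As a minimal, self-contained sketch of how such a flag is defined and read (the demo binary below is illustrative, not part of this commit):

  # flag_demo.py - hypothetical standalone example, not from the repository.
  from absl import app
  from absl import flags

  flags.DEFINE_string(
      'tf_data_service', default=None, help='The tf.data service address')
  FLAGS = flags.FLAGS

  def main(_):
    # None (the default) disables tf.data service; otherwise an address
    # such as 'grpc://tf-data-service:5050' is expected.
    print('tf_data_service =', FLAGS.tf_data_service)

  if __name__ == '__main__':
    app.run(main)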
official/core/input_reader.py

@@ -100,6 +100,7 @@ class InputReader:
     self._cache = params.cache
     self._cycle_length = params.cycle_length
     self._block_length = params.block_length
+    self._deterministic = params.deterministic
     self._sharding = params.sharding
     self._examples_consume = params.examples_consume
     self._tfds_split = params.tfds_split
@@ -114,6 +115,11 @@ class InputReader:
     self._postprocess_fn = postprocess_fn
     self._seed = _get_random_integer()
+    self._enable_tf_data_service = (
+        params.enable_tf_data_service and params.tf_data_service_address)
+    self._tf_data_service_address = params.tf_data_service_address
+    self._tf_data_service_job_name = params.tf_data_service_job_name
+
   def _read_sharded_files(
       self,
       input_context: Optional[tf.distribute.InputContext] = None):
@@ -134,8 +140,11 @@ class InputReader:
           seed=self._seed,
           reshuffle_each_iteration=True)
+    # Do not enable sharding if tf.data service is enabled, as sharding will be
+    # handled inside tf.data service.
     if self._sharding and input_context and (
-        input_context.num_input_pipelines > 1):
+        input_context.num_input_pipelines > 1 and
+        not self._enable_tf_data_service):
       dataset = dataset.shard(input_context.num_input_pipelines,
                               input_context.input_pipeline_id)
     if self._is_training:
@@ -145,7 +154,8 @@ class InputReader:
         map_func=self._dataset_fn,
         cycle_length=self._cycle_length,
         block_length=self._block_length,
-        num_parallel_calls=tf.data.experimental.AUTOTUNE)
+        num_parallel_calls=tf.data.experimental.AUTOTUNE,
+        deterministic=self._deterministic)
     return dataset

   def _read_single_file(
@@ -161,8 +171,11 @@ class InputReader:
     options.experimental_distribute.auto_shard_policy = (
         tf.data.experimental.AutoShardPolicy.OFF)
     dataset = dataset.with_options(options)
+    # Do not enable sharding if tf.data service is enabled, as sharding will be
+    # handled inside tf.data service.
     if self._sharding and input_context and (
-        input_context.num_input_pipelines > 1):
+        input_context.num_input_pipelines > 1 and
+        not self._enable_tf_data_service):
       dataset = dataset.shard(input_context.num_input_pipelines,
                               input_context.input_pipeline_id)
     if self._is_training:
@@ -243,4 +256,18 @@ class InputReader:
         per_replica_batch_size, drop_remainder=self._drop_remainder)
     dataset = maybe_map_fn(dataset, self._postprocess_fn)

+    if self._enable_tf_data_service:
+      dataset = dataset.apply(
+          tf.data.experimental.service.distribute(
+              processing_mode='parallel_epochs',
+              service=self._tf_data_service_address,
+              job_name=self._tf_data_service_job_name))
+      dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
+
+    if self._deterministic is not None:
+      options = tf.data.Options()
+      options.experimental_deterministic = self._deterministic
+      dataset = dataset.with_options(options)
     return dataset.prefetch(tf.data.experimental.AUTOTUNE)
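
The heart of the change is tf.data.experimental.service.distribute (available since TF 2.3), which hands the preprocessing portion of a pipeline off to a fleet of tf.data service workers. A minimal standalone sketch of the same pattern; the toy dataset and the service address are illustrative assumptions, not taken from this commit:

  import tensorflow as tf

  def make_dataset(service_address=None, job_name=None):
    dataset = tf.data.Dataset.range(100)  # stand-in for the real file-based input
    dataset = dataset.map(
        lambda x: x * 2,  # stand-in for the real decode/augment step
        num_parallel_calls=tf.data.experimental.AUTOTUNE)
    if service_address:
      # Everything above this point now runs on the tf.data service workers.
      # 'parallel_epochs' means each worker processes a full copy of the data;
      # passing the same job_name from several consumers shares one job.
      dataset = dataset.apply(
          tf.data.experimental.service.distribute(
              processing_mode='parallel_epochs',
              service=service_address,  # e.g. 'grpc://tf-data-service:5050'
              job_name=job_name))
    return dataset.prefetch(tf.data.experimental.AUTOTUNE)

This also explains why the manual dataset.shard() calls above are skipped when the service is enabled: splitting the data is the dispatcher's job once the pipeline is offloaded.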
official/core/train_utils.py

@@ -60,10 +60,18 @@ def parse_configuration(flags_obj):
     params = hyperparams.override_params_dict(
         params, config_file, is_strict=True)

-  # 3. Override the TPU address.
+  # 3. Override the TPU address and tf.data service address.
   params.override({
       'runtime': {
           'tpu': flags_obj.tpu,
       },
+      'task': {
+          'train_data': {
+              'tf_data_service_address': flags_obj.tf_data_service,
+          },
+          'validation_data': {
+              'tf_data_service_address': flags_obj.tf_data_service,
+          }
+      }
   })
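
params.override merges a nested dictionary into the existing config tree rather than replacing whole subtrees, which is what lets the flag touch only tf_data_service_address while leaving the rest of train_data intact. A rough sketch of that merge semantics using a plain dict helper (the helper is illustrative; the real implementation lives in official.modeling.hyperparams):

  def deep_override(params, overrides):
    """Recursively merges `overrides` into `params` in place."""
    for key, value in overrides.items():
      if isinstance(value, dict):
        deep_override(params.setdefault(key, {}), value)
      else:
        params[key] = value
    return params

  params = {'runtime': {'tpu': None, 'num_gpus': 1},
            'task': {'train_data': {'global_batch_size': 1024}}}
  deep_override(params, {
      'runtime': {'tpu': 'my-tpu'},
      'task': {'train_data': {'tf_data_service_address': 'grpc://host:5050'}},
  })
  # 'num_gpus' and 'global_batch_size' survive; only the named keys change.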
official/modeling/hyperparams/config_definitions.py

@@ -48,11 +48,22 @@ class DataConfig(base_config.Config):
       interleaving files.
     block_length: The number of consecutive elements to produce from each input
       element before cycling to another input element when interleaving files.
+    deterministic: A boolean controlling whether determinism should be enforced.
     sharding: Whether sharding is used in the input pipeline.
     examples_consume: An `integer` specifying the number of examples it will
       produce. If positive, it only takes this number of examples and raises
       tf.error.OutOfRangeError after that. Default is -1, meaning it will
       exhaust all the examples in the dataset.
+    enable_tf_data_service: A boolean indicating whether to enable tf.data
+      service for the input pipeline.
+    tf_data_service_address: The URI of a tf.data service to offload
+      preprocessing onto during training. The URI should be in the format
+      "protocol://address", e.g. "grpc://tf-data-service:5050". It can be
+      overridden by `FLAGS.tf_data_service` flag in the binary.
+    tf_data_service_job_name: The name of the tf.data service job. This
+      argument makes it possible for multiple datasets to share the same job.
+      The default behavior is that the dataset creates anonymous, exclusively
+      owned jobs.
     tfds_data_dir: A str specifying the directory to read/write TFDS data.
     tfds_download: A bool to indicate whether to download data using TFDS.
     tfds_as_supervised: A bool. When loading dataset from TFDS, if True, the
@@ -74,8 +85,12 @@ class DataConfig(base_config.Config):
   cache: bool = False
   cycle_length: int = 8
   block_length: int = 1
+  deterministic: Optional[bool] = None
   sharding: bool = True
   examples_consume: int = -1
+  enable_tf_data_service: bool = False
+  tf_data_service_address: Optional[str] = None
+  tf_data_service_job_name: Optional[str] = None
   tfds_data_dir: str = ""
   tfds_download: bool = False
   tfds_as_supervised: bool = False
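
Putting the pieces together, the new DataConfig fields can be exercised end to end. A condensed sketch with a dataclass standing in for base_config.Config (the field names match the commit; everything else is illustrative):

  from dataclasses import dataclass
  from typing import Optional

  @dataclass
  class DataConfig:
    # Only the fields added by this commit are shown.
    deterministic: Optional[bool] = None
    enable_tf_data_service: bool = False
    tf_data_service_address: Optional[str] = None
    tf_data_service_job_name: Optional[str] = None

  cfg = DataConfig(
      enable_tf_data_service=True,
      tf_data_service_address='grpc://tf-data-service:5050',  # example address
      tf_data_service_job_name='shared_train_job')            # example name
  # InputReader engages the service only when both the boolean and the address
  # are set, mirroring `enable_tf_data_service and tf_data_service_address`.
  assert cfg.enable_tf_data_service and cfg.tf_data_service_address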