ModelZoo / ResNet50_tensorflow · Commit 4270e416

Authored Dec 04, 2019 by Yeqing Li, committed by A. Unique TensorFlower on Dec 04, 2019
Internal change
PiperOrigin-RevId: 283770513
Parent: ad56514f
Showing 2 changed files with 55 additions and 109 deletions (+55, -109):

  official/modeling/training/distributed_executor.py  (+13, -86)
  official/utils/hyperparams_flags.py                 (+42, -23)
official/modeling/training/distributed_executor.py
@@ -32,32 +32,13 @@ import tensorflow as tf
 from typing import Optional, Dict, List, Text, Callable, Union, Iterator, Any

 from official.modeling.hyperparams import params_dict
 from official.utils.misc import tpu_lib
+from official.utils.misc import distribution_utils
+from official.utils import hyperparams_flags

 FLAGS = flags.FLAGS

-
-def strategy_flags_dict():
-  """Returns TPU related flags in a dictionary."""
-  return {
-      # TPUStrategy related flags.
-      'tpu': FLAGS.tpu,
-      # MultiWorkerMirroredStrategy related flags.
-      'worker_hosts': FLAGS.worker_hosts,
-      'task_index': FLAGS.task_index,
-  }
-
-
-def hparam_flags_dict():
-  """Returns model params related flags in a dictionary."""
-  return {
-      'data_dir': FLAGS.data_dir,
-      'model_dir': FLAGS.model_dir,
-      'train_batch_size': FLAGS.train_batch_size,
-      'eval_batch_size': FLAGS.eval_batch_size,
-      'precision': FLAGS.precision,
-      'config_file': FLAGS.config_file,
-      'params_override': FLAGS.params_override,
-  }
+strategy_flags_dict = hyperparams_flags.strategy_flags_dict
+hparam_flags_dict = hyperparams_flags.hparam_flags_dict


 def _save_checkpoint(checkpoint, model_dir, checkpoint_prefix):
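The two aliases added above preserve the old import surface. A minimal sketch, not part of the commit, of a call site that keeps working after the move (it assumes absl flags have already been defined and parsed elsewhere in the program):

# Hypothetical call site: the helper now lives in official/utils/hyperparams_flags.py,
# but the alias in distributed_executor.py still resolves to the same function.
from official.modeling.training import distributed_executor

# Must run after absl has parsed the command line; FLAGS access fails otherwise.
strategy_flags = distributed_executor.strategy_flags_dict()
print(strategy_flags['tpu'], strategy_flags['worker_hosts'])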
@@ -647,7 +628,6 @@ class DistributedExecutor(object):
     return NotImplementedError('Unimplmented function.')


-# TODO(yeqing): Add unit test for MultiWorkerMirroredStrategy.
 class ExecutorBuilder(object):
   """Builder of DistributedExecutor.
@@ -692,8 +672,15 @@ class ExecutorBuilder(object):
   """

   def __init__(self, strategy_type=None, strategy_config=None):
     self._strategy_config = strategy_config
-    self._strategy = self._build_strategy(strategy_type)
+    num_workers = distribution_utils.configure_cluster(
+        strategy_config.worker_hosts, strategy_config.task_index)
+    self._strategy = distribution_utils.get_distribution_strategy(
+        distribution_strategy=strategy_type,
+        num_gpus=strategy_config.num_gpus,
+        num_workers=num_workers,
+        all_reduce_alg=strategy_config.all_reduce_alg,
+        num_packs=strategy_config.num_packs,
+        tpu_address=strategy_config.tpu)

   @property
   def strategy(self):
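A usage sketch of the rewritten constructor, not taken from the commit: the SimpleNamespace stand-in and its values are hypothetical; only the attribute names read by `__init__` above (worker_hosts, task_index, num_gpus, all_reduce_alg, num_packs, tpu) come from the diff.

from types import SimpleNamespace

from official.modeling.training import distributed_executor

# Hypothetical single-host configuration (no multi-worker cluster).
strategy_config = SimpleNamespace(
    worker_hosts=None,
    task_index=-1,
    num_gpus=2,
    all_reduce_alg=None,
    num_packs=1,
    tpu=None)

builder = distributed_executor.ExecutorBuilder(
    strategy_type='mirrored', strategy_config=strategy_config)
print(type(builder.strategy).__name__)  # expected: MirroredStrategy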
@@ -705,66 +692,6 @@ class ExecutorBuilder(object):
     """Sets default summary writer for the current thread."""
     self._strategy = new_strategy

-  def _build_strategy(self, strategy_type):
-    """Builds tf.distribute.Strategy instance.
-
-    Args:
-      strategy_type: string. One of 'tpu', 'one_device_gpu', 'mirrored',
-        'multi_worker_mirrored'.
-
-    Returns:
-      An tf.distribute.Strategy object. Returns None if strategy_type is None.
-    """
-    if strategy_type is None:
-      return None
-    if strategy_type == 'tpu':
-      return self._build_tpu_strategy()
-    elif strategy_type == 'one_device_gpu':
-      return tf.distribute.OneDeviceStrategy("device:GPU:0")
-    elif strategy_type == 'mirrored':
-      return self._build_mirrored_strategy()
-    elif strategy_type == 'multi_worker_mirrored':
-      return self._build_multiworker_mirrored_strategy()
-    else:
-      raise NotImplementedError(
-          'Unsupport accelerator type "%s"' % strategy_type)
-
-  def _build_mirrored_strategy(self):
-    """Builds a MirroredStrategy object."""
-    return tf.distribute.MirroredStrategy()
-
-  def _build_tpu_strategy(self):
-    """Builds a TPUStrategy object."""
-    tpu = self._strategy_config.tpu
-    logging.info('Use TPU at %s', tpu if tpu is not None else '')
-    cluster_resolver = tpu_lib.tpu_initialize(tpu)
-    strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
-    return strategy
-
-  def _build_multiworker_mirrored_strategy(self):
-    """Builds a MultiWorkerMirroredStrategy object."""
-    worker_hosts = self._strategy_config.worker_hosts
-    if worker_hosts is not None:
-      # Set TF_CONFIG environment variable
-      worker_hosts = worker_hosts.split(',')
-      task_index = self._strategy_config.task_index
-      os.environ['TF_CONFIG'] = json.dumps({
-          'cluster': {'worker': worker_hosts},
-          'task': {'type': 'worker', 'index': task_index}
-      })
-    multiworker_strategy = (
-        tf.distribute.experimental.MultiWorkerMirroredStrategy())
-    return multiworker_strategy
-
   def build_executor(self,
                      class_ctor=DistributedExecutor,
                      ...
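The removed `_build_multiworker_mirrored_strategy` was the one place that wrote the TF_CONFIG cluster spec by hand; that responsibility now sits behind `distribution_utils.configure_cluster`. For reference, a standalone sketch of the same TF_CONFIG payload for a hypothetical two-worker job (the host:port values are placeholders):

import json
import os

import tensorflow as tf

# Placeholder worker addresses; in the flag-driven flow these come from --worker_hosts.
worker_hosts = '10.0.0.1:5000,10.0.0.2:5000'.split(',')
task_index = 0  # this process acts as worker 0

# Same cluster-spec layout the removed helper produced before building the strategy.
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {'worker': worker_hosts},
    'task': {'type': 'worker', 'index': task_index},
})

strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()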
official/utils/hyperparams_flags.py
@@ -20,6 +20,7 @@ from __future__ import division
 from __future__ import print_function

 from absl import flags
+from official.utils.flags import core as flags_core

 FLAGS = flags.FLAGS
@@ -68,33 +69,51 @@ def define_common_hparams_flags():
       'The final override order of parameters: default_model_params --> '
       'params from config_file --> params in params_override.'
       'See also the help message of `--config_file`.'))
-  flags.DEFINE_string(
-      'strategy_type', 'mirrored', 'Type of distribute strategy.'
-      'One of mirrored, tpu and multiworker.')
+  flags.DEFINE_integer('save_checkpoint_freq', None,
+                       'Number of steps to save checkpoint.')


 def initialize_common_flags():
   """Define the common flags across models."""
-  key_flags = []
   define_common_hparams_flags()

-  flags.DEFINE_string(
-      'tpu',
-      default=None,
-      help='The Cloud TPU to use for training. This should be either the name '
-      'used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 '
-      'url.')
-  # Parameters for MultiWorkerMirroredStrategy
-  flags.DEFINE_string(
-      'worker_hosts',
-      default=None,
-      help='Comma-separated list of worker ip:port pairs for running '
-      'multi-worker models with distribution strategy. The user would '
-      'start the program on each host with identical value for this flag.')
-  flags.DEFINE_integer(
-      'task_index', 0,
-      'If multi-worker training, the task_index of this worker.')
-  flags.DEFINE_integer('save_checkpoint_freq', None,
-                       'Number of steps to save checkpoint.')
-  return key_flags
+  flags_core.define_device(tpu=True)
+  flags_core.define_base(
+      num_gpu=True, model_dir=False, data_dir=False, batch_size=False)
+  flags_core.define_distribution(worker_hosts=True, task_index=True)
+  flags_core.define_performance(all_reduce_alg=True, num_packs=True)
+
+  # Reset the default value of num_gpus to zero.
+  FLAGS.num_gpus = 0
+
+  flags.DEFINE_string(
+      'strategy_type', 'mirrored', 'Type of distribute strategy.'
+      'One of mirrored, tpu and multiworker.')
+
+
+def strategy_flags_dict():
+  """Returns TPU and/or GPU related flags in a dictionary."""
+  return {
+      # TPUStrategy related flags.
+      'tpu': FLAGS.tpu,
+      # MultiWorkerMirroredStrategy related flags.
+      'all_reduce_alg': FLAGS.all_reduce_alg,
+      'worker_hosts': FLAGS.worker_hosts,
+      'task_index': FLAGS.task_index,
+      # MirroredStrategy and OneDeviceStrategy
+      'num_gpus': FLAGS.num_gpus,
+      'num_packs': FLAGS.num_packs,
+  }
+
+
+def hparam_flags_dict():
+  """Returns model params related flags in a dictionary."""
+  return {
+      'data_dir': FLAGS.data_dir,
+      'model_dir': FLAGS.model_dir,
+      'train_batch_size': FLAGS.train_batch_size,
+      'eval_batch_size': FLAGS.eval_batch_size,
+      'precision': FLAGS.precision,
+      'config_file': FLAGS.config_file,
+      'params_override': FLAGS.params_override,
+  }
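A hedged end-to-end sketch, not part of the commit, of how the relocated helpers are meant to be consumed: flags are defined at import time via `initialize_common_flags`, and after parsing, the dictionary from `strategy_flags_dict` can feed `distribution_utils.get_distribution_strategy` (the keyword names are taken from the distributed_executor.py hunk above; the `main` wrapper is illustrative).

from absl import app
from absl import flags

from official.utils import hyperparams_flags
from official.utils.misc import distribution_utils

hyperparams_flags.initialize_common_flags()  # defines --strategy_type, --tpu, --num_gpus, ...
FLAGS = flags.FLAGS


def main(_):
  strategy_flags = hyperparams_flags.strategy_flags_dict()
  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=FLAGS.strategy_type,
      num_gpus=strategy_flags['num_gpus'],
      all_reduce_alg=strategy_flags['all_reduce_alg'],
      num_packs=strategy_flags['num_packs'],
      tpu_address=strategy_flags['tpu'])
  print('Distribution strategy:', type(strategy).__name__)


if __name__ == '__main__':
  app.run(main)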