Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
33a4c207
Commit
33a4c207
authored
Sep 12, 2020
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Sep 12, 2020
Browse files
[Clean up] Consolidate distribution utils.
PiperOrigin-RevId: 331359058
parent
41a1e1d6
Changes
25
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
18 additions
and
30 deletions
+18
-30
official/vision/beta/train.py
official/vision/beta/train.py
+2
-2
official/vision/detection/main.py
official/vision/detection/main.py
+5
-14
official/vision/image_classification/classifier_trainer.py
official/vision/image_classification/classifier_trainer.py
+5
-6
official/vision/image_classification/mnist_main.py
official/vision/image_classification/mnist_main.py
+3
-4
official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py
...n/image_classification/resnet/resnet_ctl_imagenet_main.py
+3
-4
No files found.
official/vision/beta/train.py
View file @
33a4c207
...
@@ -19,13 +19,13 @@ from absl import app
...
@@ -19,13 +19,13 @@ from absl import app
from
absl
import
flags
from
absl
import
flags
import
gin
import
gin
from
official.common
import
distribute_utils
from
official.common
import
flags
as
tfm_flags
from
official.common
import
flags
as
tfm_flags
from
official.common
import
registry_imports
# pylint: disable=unused-import
from
official.common
import
registry_imports
# pylint: disable=unused-import
from
official.core
import
task_factory
from
official.core
import
task_factory
from
official.core
import
train_lib
from
official.core
import
train_lib
from
official.core
import
train_utils
from
official.core
import
train_utils
from
official.modeling
import
performance
from
official.modeling
import
performance
from
official.utils.misc
import
distribution_utils
FLAGS
=
flags
.
FLAGS
FLAGS
=
flags
.
FLAGS
...
@@ -46,7 +46,7 @@ def main(_):
...
@@ -46,7 +46,7 @@ def main(_):
if
params
.
runtime
.
mixed_precision_dtype
:
if
params
.
runtime
.
mixed_precision_dtype
:
performance
.
set_mixed_precision_policy
(
params
.
runtime
.
mixed_precision_dtype
,
performance
.
set_mixed_precision_policy
(
params
.
runtime
.
mixed_precision_dtype
,
params
.
runtime
.
loss_scale
)
params
.
runtime
.
loss_scale
)
distribution_strategy
=
distribut
ion
_utils
.
get_distribution_strategy
(
distribution_strategy
=
distribut
e
_utils
.
get_distribution_strategy
(
distribution_strategy
=
params
.
runtime
.
distribution_strategy
,
distribution_strategy
=
params
.
runtime
.
distribution_strategy
,
all_reduce_alg
=
params
.
runtime
.
all_reduce_alg
,
all_reduce_alg
=
params
.
runtime
.
all_reduce_alg
,
num_gpus
=
params
.
runtime
.
num_gpus
,
num_gpus
=
params
.
runtime
.
num_gpus
,
...
...
official/vision/detection/main.py
View file @
33a4c207
...
@@ -14,28 +14,19 @@
...
@@ -14,28 +14,19 @@
# ==============================================================================
# ==============================================================================
"""Main function to train various object detection models."""
"""Main function to train various object detection models."""
from
__future__
import
absolute_import
from
__future__
import
division
# from __future__ import google_type_annotations
from
__future__
import
print_function
import
functools
import
functools
import
pprint
import
pprint
# pylint: disable=g-bad-import-order
# Import libraries
import
tensorflow
as
tf
from
absl
import
app
from
absl
import
app
from
absl
import
flags
from
absl
import
flags
from
absl
import
logging
from
absl
import
logging
# pylint: enable=g-bad-import-order
import
tensorflow
as
tf
from
official.common
import
distribute_utils
from
official.modeling.hyperparams
import
params_dict
from
official.modeling.hyperparams
import
params_dict
from
official.modeling.training
import
distributed_executor
as
executor
from
official.modeling.training
import
distributed_executor
as
executor
from
official.utils
import
hyperparams_flags
from
official.utils
import
hyperparams_flags
from
official.utils.flags
import
core
as
flags_core
from
official.utils.flags
import
core
as
flags_core
from
official.utils.misc
import
distribution_utils
from
official.utils.misc
import
keras_utils
from
official.utils.misc
import
keras_utils
from
official.vision.detection.configs
import
factory
as
config_factory
from
official.vision.detection.configs
import
factory
as
config_factory
from
official.vision.detection.dataloader
import
input_reader
from
official.vision.detection.dataloader
import
input_reader
...
@@ -87,9 +78,9 @@ def run_executor(params,
...
@@ -87,9 +78,9 @@ def run_executor(params,
strategy
=
prebuilt_strategy
strategy
=
prebuilt_strategy
else
:
else
:
strategy_config
=
params
.
strategy_config
strategy_config
=
params
.
strategy_config
distribut
ion
_utils
.
configure_cluster
(
strategy_config
.
worker_hosts
,
distribut
e
_utils
.
configure_cluster
(
strategy_config
.
worker_hosts
,
strategy_config
.
task_index
)
strategy_config
.
task_index
)
strategy
=
distribut
ion
_utils
.
get_distribution_strategy
(
strategy
=
distribut
e
_utils
.
get_distribution_strategy
(
distribution_strategy
=
params
.
strategy_type
,
distribution_strategy
=
params
.
strategy_type
,
num_gpus
=
strategy_config
.
num_gpus
,
num_gpus
=
strategy_config
.
num_gpus
,
all_reduce_alg
=
strategy_config
.
all_reduce_alg
,
all_reduce_alg
=
strategy_config
.
all_reduce_alg
,
...
...
official/vision/image_classification/classifier_trainer.py
View file @
33a4c207
...
@@ -23,11 +23,10 @@ from absl import app
...
@@ -23,11 +23,10 @@ from absl import app
from
absl
import
flags
from
absl
import
flags
from
absl
import
logging
from
absl
import
logging
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.common
import
distribute_utils
from
official.modeling
import
hyperparams
from
official.modeling
import
hyperparams
from
official.modeling
import
performance
from
official.modeling
import
performance
from
official.utils
import
hyperparams_flags
from
official.utils
import
hyperparams_flags
from
official.utils.misc
import
distribution_utils
from
official.utils.misc
import
keras_utils
from
official.utils.misc
import
keras_utils
from
official.vision.image_classification
import
callbacks
as
custom_callbacks
from
official.vision.image_classification
import
callbacks
as
custom_callbacks
from
official.vision.image_classification
import
dataset_factory
from
official.vision.image_classification
import
dataset_factory
...
@@ -291,17 +290,17 @@ def train_and_eval(
...
@@ -291,17 +290,17 @@ def train_and_eval(
"""Runs the train and eval path using compile/fit."""
"""Runs the train and eval path using compile/fit."""
logging
.
info
(
'Running train and eval.'
)
logging
.
info
(
'Running train and eval.'
)
distribut
ion
_utils
.
configure_cluster
(
params
.
runtime
.
worker_hosts
,
distribut
e
_utils
.
configure_cluster
(
params
.
runtime
.
worker_hosts
,
params
.
runtime
.
task_index
)
params
.
runtime
.
task_index
)
# Note: for TPUs, strategy and scope should be created before the dataset
# Note: for TPUs, strategy and scope should be created before the dataset
strategy
=
strategy_override
or
distribut
ion
_utils
.
get_distribution_strategy
(
strategy
=
strategy_override
or
distribut
e
_utils
.
get_distribution_strategy
(
distribution_strategy
=
params
.
runtime
.
distribution_strategy
,
distribution_strategy
=
params
.
runtime
.
distribution_strategy
,
all_reduce_alg
=
params
.
runtime
.
all_reduce_alg
,
all_reduce_alg
=
params
.
runtime
.
all_reduce_alg
,
num_gpus
=
params
.
runtime
.
num_gpus
,
num_gpus
=
params
.
runtime
.
num_gpus
,
tpu_address
=
params
.
runtime
.
tpu
)
tpu_address
=
params
.
runtime
.
tpu
)
strategy_scope
=
distribut
ion
_utils
.
get_strategy_scope
(
strategy
)
strategy_scope
=
distribut
e
_utils
.
get_strategy_scope
(
strategy
)
logging
.
info
(
'Detected %d devices.'
,
logging
.
info
(
'Detected %d devices.'
,
strategy
.
num_replicas_in_sync
if
strategy
else
1
)
strategy
.
num_replicas_in_sync
if
strategy
else
1
)
...
...
official/vision/image_classification/mnist_main.py
View file @
33a4c207
...
@@ -25,9 +25,8 @@ from absl import flags
...
@@ -25,9 +25,8 @@ from absl import flags
from
absl
import
logging
from
absl
import
logging
import
tensorflow
as
tf
import
tensorflow
as
tf
import
tensorflow_datasets
as
tfds
import
tensorflow_datasets
as
tfds
from
official.common
import
distribute_utils
from
official.utils.flags
import
core
as
flags_core
from
official.utils.flags
import
core
as
flags_core
from
official.utils.misc
import
distribution_utils
from
official.utils.misc
import
model_helpers
from
official.utils.misc
import
model_helpers
from
official.vision.image_classification.resnet
import
common
from
official.vision.image_classification.resnet
import
common
...
@@ -82,12 +81,12 @@ def run(flags_obj, datasets_override=None, strategy_override=None):
...
@@ -82,12 +81,12 @@ def run(flags_obj, datasets_override=None, strategy_override=None):
Returns:
Returns:
Dictionary of training and eval stats.
Dictionary of training and eval stats.
"""
"""
strategy
=
strategy_override
or
distribut
ion
_utils
.
get_distribution_strategy
(
strategy
=
strategy_override
or
distribut
e
_utils
.
get_distribution_strategy
(
distribution_strategy
=
flags_obj
.
distribution_strategy
,
distribution_strategy
=
flags_obj
.
distribution_strategy
,
num_gpus
=
flags_obj
.
num_gpus
,
num_gpus
=
flags_obj
.
num_gpus
,
tpu_address
=
flags_obj
.
tpu
)
tpu_address
=
flags_obj
.
tpu
)
strategy_scope
=
distribut
ion
_utils
.
get_strategy_scope
(
strategy
)
strategy_scope
=
distribut
e
_utils
.
get_strategy_scope
(
strategy
)
mnist
=
tfds
.
builder
(
'mnist'
,
data_dir
=
flags_obj
.
data_dir
)
mnist
=
tfds
.
builder
(
'mnist'
,
data_dir
=
flags_obj
.
data_dir
)
if
flags_obj
.
download
:
if
flags_obj
.
download
:
...
...
official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py
View file @
33a4c207
...
@@ -23,10 +23,9 @@ from absl import flags
...
@@ -23,10 +23,9 @@ from absl import flags
from
absl
import
logging
from
absl
import
logging
import
orbit
import
orbit
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.common
import
distribute_utils
from
official.modeling
import
performance
from
official.modeling
import
performance
from
official.utils.flags
import
core
as
flags_core
from
official.utils.flags
import
core
as
flags_core
from
official.utils.misc
import
distribution_utils
from
official.utils.misc
import
keras_utils
from
official.utils.misc
import
keras_utils
from
official.utils.misc
import
model_helpers
from
official.utils.misc
import
model_helpers
from
official.vision.image_classification.resnet
import
common
from
official.vision.image_classification.resnet
import
common
...
@@ -117,7 +116,7 @@ def run(flags_obj):
...
@@ -117,7 +116,7 @@ def run(flags_obj):
else
'channels_last'
)
else
'channels_last'
)
tf
.
keras
.
backend
.
set_image_data_format
(
data_format
)
tf
.
keras
.
backend
.
set_image_data_format
(
data_format
)
strategy
=
distribut
ion
_utils
.
get_distribution_strategy
(
strategy
=
distribut
e
_utils
.
get_distribution_strategy
(
distribution_strategy
=
flags_obj
.
distribution_strategy
,
distribution_strategy
=
flags_obj
.
distribution_strategy
,
num_gpus
=
flags_obj
.
num_gpus
,
num_gpus
=
flags_obj
.
num_gpus
,
all_reduce_alg
=
flags_obj
.
all_reduce_alg
,
all_reduce_alg
=
flags_obj
.
all_reduce_alg
,
...
@@ -144,7 +143,7 @@ def run(flags_obj):
...
@@ -144,7 +143,7 @@ def run(flags_obj):
flags_obj
.
batch_size
,
flags_obj
.
batch_size
,
flags_obj
.
log_steps
,
flags_obj
.
log_steps
,
logdir
=
flags_obj
.
model_dir
if
flags_obj
.
enable_tensorboard
else
None
)
logdir
=
flags_obj
.
model_dir
if
flags_obj
.
enable_tensorboard
else
None
)
with
distribut
ion
_utils
.
get_strategy_scope
(
strategy
):
with
distribut
e
_utils
.
get_strategy_scope
(
strategy
):
runnable
=
resnet_runnable
.
ResnetRunnable
(
flags_obj
,
time_callback
,
runnable
=
resnet_runnable
.
ResnetRunnable
(
flags_obj
,
time_callback
,
per_epoch_steps
)
per_epoch_steps
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment