Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
420fd1cf
Commit
420fd1cf
authored
Sep 12, 2020
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Sep 12, 2020
Browse files
[Clean up] Consolidate distribution utils.
PiperOrigin-RevId: 331359058
parent
c7647f11
Changes
25
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
18 additions
and
30 deletions
+18
-30
official/vision/beta/train.py
official/vision/beta/train.py
+2
-2
official/vision/detection/main.py
official/vision/detection/main.py
+5
-14
official/vision/image_classification/classifier_trainer.py
official/vision/image_classification/classifier_trainer.py
+5
-6
official/vision/image_classification/mnist_main.py
official/vision/image_classification/mnist_main.py
+3
-4
official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py
...n/image_classification/resnet/resnet_ctl_imagenet_main.py
+3
-4
No files found.
official/vision/beta/train.py
View file @
420fd1cf
...
...
@@ -19,13 +19,13 @@ from absl import app
from
absl
import
flags
import
gin
from
official.common
import
distribute_utils
from
official.common
import
flags
as
tfm_flags
from
official.common
import
registry_imports
# pylint: disable=unused-import
from
official.core
import
task_factory
from
official.core
import
train_lib
from
official.core
import
train_utils
from
official.modeling
import
performance
from
official.utils.misc
import
distribution_utils
FLAGS
=
flags
.
FLAGS
...
...
@@ -46,7 +46,7 @@ def main(_):
if
params
.
runtime
.
mixed_precision_dtype
:
performance
.
set_mixed_precision_policy
(
params
.
runtime
.
mixed_precision_dtype
,
params
.
runtime
.
loss_scale
)
distribution_strategy
=
distribut
ion
_utils
.
get_distribution_strategy
(
distribution_strategy
=
distribut
e
_utils
.
get_distribution_strategy
(
distribution_strategy
=
params
.
runtime
.
distribution_strategy
,
all_reduce_alg
=
params
.
runtime
.
all_reduce_alg
,
num_gpus
=
params
.
runtime
.
num_gpus
,
...
...
official/vision/detection/main.py
View file @
420fd1cf
...
...
@@ -14,28 +14,19 @@
# ==============================================================================
"""Main function to train various object detection models."""
from
__future__
import
absolute_import
from
__future__
import
division
# from __future__ import google_type_annotations
from
__future__
import
print_function
import
functools
import
pprint
# pylint: disable=g-bad-import-order
# Import libraries
import
tensorflow
as
tf
from
absl
import
app
from
absl
import
flags
from
absl
import
logging
# pylint: enable=g-bad-import-order
import
tensorflow
as
tf
from
official.common
import
distribute_utils
from
official.modeling.hyperparams
import
params_dict
from
official.modeling.training
import
distributed_executor
as
executor
from
official.utils
import
hyperparams_flags
from
official.utils.flags
import
core
as
flags_core
from
official.utils.misc
import
distribution_utils
from
official.utils.misc
import
keras_utils
from
official.vision.detection.configs
import
factory
as
config_factory
from
official.vision.detection.dataloader
import
input_reader
...
...
@@ -87,9 +78,9 @@ def run_executor(params,
strategy
=
prebuilt_strategy
else
:
strategy_config
=
params
.
strategy_config
distribut
ion
_utils
.
configure_cluster
(
strategy_config
.
worker_hosts
,
strategy_config
.
task_index
)
strategy
=
distribut
ion
_utils
.
get_distribution_strategy
(
distribut
e
_utils
.
configure_cluster
(
strategy_config
.
worker_hosts
,
strategy_config
.
task_index
)
strategy
=
distribut
e
_utils
.
get_distribution_strategy
(
distribution_strategy
=
params
.
strategy_type
,
num_gpus
=
strategy_config
.
num_gpus
,
all_reduce_alg
=
strategy_config
.
all_reduce_alg
,
...
...
official/vision/image_classification/classifier_trainer.py
View file @
420fd1cf
...
...
@@ -23,11 +23,10 @@ from absl import app
from
absl
import
flags
from
absl
import
logging
import
tensorflow
as
tf
from
official.common
import
distribute_utils
from
official.modeling
import
hyperparams
from
official.modeling
import
performance
from
official.utils
import
hyperparams_flags
from
official.utils.misc
import
distribution_utils
from
official.utils.misc
import
keras_utils
from
official.vision.image_classification
import
callbacks
as
custom_callbacks
from
official.vision.image_classification
import
dataset_factory
...
...
@@ -291,17 +290,17 @@ def train_and_eval(
"""Runs the train and eval path using compile/fit."""
logging
.
info
(
'Running train and eval.'
)
distribut
ion
_utils
.
configure_cluster
(
params
.
runtime
.
worker_hosts
,
params
.
runtime
.
task_index
)
distribut
e
_utils
.
configure_cluster
(
params
.
runtime
.
worker_hosts
,
params
.
runtime
.
task_index
)
# Note: for TPUs, strategy and scope should be created before the dataset
strategy
=
strategy_override
or
distribut
ion
_utils
.
get_distribution_strategy
(
strategy
=
strategy_override
or
distribut
e
_utils
.
get_distribution_strategy
(
distribution_strategy
=
params
.
runtime
.
distribution_strategy
,
all_reduce_alg
=
params
.
runtime
.
all_reduce_alg
,
num_gpus
=
params
.
runtime
.
num_gpus
,
tpu_address
=
params
.
runtime
.
tpu
)
strategy_scope
=
distribut
ion
_utils
.
get_strategy_scope
(
strategy
)
strategy_scope
=
distribut
e
_utils
.
get_strategy_scope
(
strategy
)
logging
.
info
(
'Detected %d devices.'
,
strategy
.
num_replicas_in_sync
if
strategy
else
1
)
...
...
official/vision/image_classification/mnist_main.py
View file @
420fd1cf
...
...
@@ -25,9 +25,8 @@ from absl import flags
from
absl
import
logging
import
tensorflow
as
tf
import
tensorflow_datasets
as
tfds
from
official.common
import
distribute_utils
from
official.utils.flags
import
core
as
flags_core
from
official.utils.misc
import
distribution_utils
from
official.utils.misc
import
model_helpers
from
official.vision.image_classification.resnet
import
common
...
...
@@ -82,12 +81,12 @@ def run(flags_obj, datasets_override=None, strategy_override=None):
Returns:
Dictionary of training and eval stats.
"""
strategy
=
strategy_override
or
distribut
ion
_utils
.
get_distribution_strategy
(
strategy
=
strategy_override
or
distribut
e
_utils
.
get_distribution_strategy
(
distribution_strategy
=
flags_obj
.
distribution_strategy
,
num_gpus
=
flags_obj
.
num_gpus
,
tpu_address
=
flags_obj
.
tpu
)
strategy_scope
=
distribut
ion
_utils
.
get_strategy_scope
(
strategy
)
strategy_scope
=
distribut
e
_utils
.
get_strategy_scope
(
strategy
)
mnist
=
tfds
.
builder
(
'mnist'
,
data_dir
=
flags_obj
.
data_dir
)
if
flags_obj
.
download
:
...
...
official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py
View file @
420fd1cf
...
...
@@ -23,10 +23,9 @@ from absl import flags
from
absl
import
logging
import
orbit
import
tensorflow
as
tf
from
official.common
import
distribute_utils
from
official.modeling
import
performance
from
official.utils.flags
import
core
as
flags_core
from
official.utils.misc
import
distribution_utils
from
official.utils.misc
import
keras_utils
from
official.utils.misc
import
model_helpers
from
official.vision.image_classification.resnet
import
common
...
...
@@ -117,7 +116,7 @@ def run(flags_obj):
else
'channels_last'
)
tf
.
keras
.
backend
.
set_image_data_format
(
data_format
)
strategy
=
distribut
ion
_utils
.
get_distribution_strategy
(
strategy
=
distribut
e
_utils
.
get_distribution_strategy
(
distribution_strategy
=
flags_obj
.
distribution_strategy
,
num_gpus
=
flags_obj
.
num_gpus
,
all_reduce_alg
=
flags_obj
.
all_reduce_alg
,
...
...
@@ -144,7 +143,7 @@ def run(flags_obj):
flags_obj
.
batch_size
,
flags_obj
.
log_steps
,
logdir
=
flags_obj
.
model_dir
if
flags_obj
.
enable_tensorboard
else
None
)
with
distribut
ion
_utils
.
get_strategy_scope
(
strategy
):
with
distribut
e
_utils
.
get_strategy_scope
(
strategy
):
runnable
=
resnet_runnable
.
ResnetRunnable
(
flags_obj
,
time_callback
,
per_epoch_steps
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment