Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
5a3b762c
Commit
5a3b762c
authored
Nov 18, 2019
by
Allen Wang
Committed by
A. Unique TensorFlower
Nov 18, 2019
Browse files
Internal change
PiperOrigin-RevId: 281063737
parent
aedb9802
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
104 additions
and
75 deletions
+104
-75
official/modeling/training/distributed_executor.py
official/modeling/training/distributed_executor.py
+2
-74
official/utils/hyperparams_flags.py
official/utils/hyperparams_flags.py
+100
-0
official/vision/detection/main.py
official/vision/detection/main.py
+2
-1
No files found.
official/modeling/training/distributed_executor.py
View file @
5a3b762c
...
@@ -22,9 +22,10 @@ from __future__ import print_function
...
@@ -22,9 +22,10 @@ from __future__ import print_function
import
json
import
json
import
os
import
os
import
numpy
as
np
from
absl
import
flags
from
absl
import
flags
from
absl
import
logging
from
absl
import
logging
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
# pylint: disable=unused-import,g-import-not-at-top,redefined-outer-name,reimported
# pylint: disable=unused-import,g-import-not-at-top,redefined-outer-name,reimported
...
@@ -35,79 +36,6 @@ from official.utils.misc import tpu_lib
...
@@ -35,79 +36,6 @@ from official.utils.misc import tpu_lib
FLAGS
=
flags
.
FLAGS
FLAGS
=
flags
.
FLAGS
def
define_common_hparams_flags
():
"""Define the common flags across models."""
flags
.
DEFINE_string
(
'model_dir'
,
default
=
None
,
help
=
(
'The directory where the model and training/evaluation summaries'
'are stored.'
))
flags
.
DEFINE_integer
(
'train_batch_size'
,
default
=
None
,
help
=
'Batch size for training.'
)
flags
.
DEFINE_integer
(
'eval_batch_size'
,
default
=
None
,
help
=
'Batch size for evaluation.'
)
flags
.
DEFINE_string
(
'precision'
,
default
=
None
,
help
=
(
'Precision to use; one of: {bfloat16, float32}'
))
flags
.
DEFINE_string
(
'config_file'
,
default
=
None
,
help
=
(
'A YAML file which specifies overrides. Note that this file can be '
'used as an override template to override the default parameters '
'specified in Python. If the same parameter is specified in both '
'`--config_file` and `--params_override`, the one in '
'`--params_override` will be used finally.'
))
flags
.
DEFINE_string
(
'params_override'
,
default
=
None
,
help
=
(
'a YAML/JSON string or a YAML file which specifies additional '
'overrides over the default parameters and those specified in '
'`--config_file`. Note that this is supposed to be used only to '
'override the model parameters, but not the parameters like TPU '
'specific flags. One canonical use case of `--config_file` and '
'`--params_override` is users first define a template config file '
'using `--config_file`, then use `--params_override` to adjust the '
'minimal set of tuning parameters, for example setting up different'
' `train_batch_size`. '
'The final override order of parameters: default_model_params --> '
'params from config_file --> params in params_override.'
'See also the help message of `--config_file`.'
))
flags
.
DEFINE_string
(
'strategy_type'
,
'mirrored'
,
'Type of distribute strategy.'
'One of mirrored, tpu and multiworker.'
)
def
initialize_common_flags
():
"""Define the common flags across models."""
define_common_hparams_flags
()
flags
.
DEFINE_string
(
'tpu'
,
default
=
None
,
help
=
'The Cloud TPU to use for training. This should be either the name '
'used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 '
'url.'
)
# Parameters for MultiWorkerMirroredStrategy
flags
.
DEFINE_string
(
'worker_hosts'
,
default
=
None
,
help
=
'Comma-separated list of worker ip:port pairs for running '
'multi-worker models with distribution strategy. The user would '
'start the program on each host with identical value for this flag.'
)
flags
.
DEFINE_integer
(
'task_index'
,
0
,
'If multi-worker training, the task_index of this worker.'
)
flags
.
DEFINE_integer
(
'save_checkpoint_freq'
,
None
,
'Number of steps to save checkpoint.'
)
def
strategy_flags_dict
():
def
strategy_flags_dict
():
"""Returns TPU related flags in a dictionary."""
"""Returns TPU related flags in a dictionary."""
return
{
return
{
...
...
official/utils/hyperparams_flags.py
0 → 100644
View file @
5a3b762c
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common flags for importing hyperparameters."""
from
__future__
import
absolute_import
from
__future__
import
division
# from __future__ import google_type_annotations
from
__future__
import
print_function
from
absl
import
flags
FLAGS
=
flags
.
FLAGS
def
define_common_hparams_flags
():
"""Define the common flags across models."""
flags
.
DEFINE_string
(
'model_dir'
,
default
=
None
,
help
=
(
'The directory where the model and training/evaluation summaries'
'are stored.'
))
flags
.
DEFINE_integer
(
'train_batch_size'
,
default
=
None
,
help
=
'Batch size for training.'
)
flags
.
DEFINE_integer
(
'eval_batch_size'
,
default
=
None
,
help
=
'Batch size for evaluation.'
)
flags
.
DEFINE_string
(
'precision'
,
default
=
None
,
help
=
(
'Precision to use; one of: {bfloat16, float32}'
))
flags
.
DEFINE_string
(
'config_file'
,
default
=
None
,
help
=
(
'A YAML file which specifies overrides. Note that this file can be '
'used as an override template to override the default parameters '
'specified in Python. If the same parameter is specified in both '
'`--config_file` and `--params_override`, the one in '
'`--params_override` will be used finally.'
))
flags
.
DEFINE_string
(
'params_override'
,
default
=
None
,
help
=
(
'a YAML/JSON string or a YAML file which specifies additional '
'overrides over the default parameters and those specified in '
'`--config_file`. Note that this is supposed to be used only to '
'override the model parameters, but not the parameters like TPU '
'specific flags. One canonical use case of `--config_file` and '
'`--params_override` is users first define a template config file '
'using `--config_file`, then use `--params_override` to adjust the '
'minimal set of tuning parameters, for example setting up different'
' `train_batch_size`. '
'The final override order of parameters: default_model_params --> '
'params from config_file --> params in params_override.'
'See also the help message of `--config_file`.'
))
flags
.
DEFINE_string
(
'strategy_type'
,
'mirrored'
,
'Type of distribute strategy.'
'One of mirrored, tpu and multiworker.'
)
def
initialize_common_flags
():
"""Define the common flags across models."""
key_flags
=
[]
define_common_hparams_flags
()
flags
.
DEFINE_string
(
'tpu'
,
default
=
None
,
help
=
'The Cloud TPU to use for training. This should be either the name '
'used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 '
'url.'
)
# Parameters for MultiWorkerMirroredStrategy
flags
.
DEFINE_string
(
'worker_hosts'
,
default
=
None
,
help
=
'Comma-separated list of worker ip:port pairs for running '
'multi-worker models with distribution strategy. The user would '
'start the program on each host with identical value for this flag.'
)
flags
.
DEFINE_integer
(
'task_index'
,
0
,
'If multi-worker training, the task_index of this worker.'
)
flags
.
DEFINE_integer
(
'save_checkpoint_freq'
,
None
,
'Number of steps to save checkpoint.'
)
return
key_flags
official/vision/detection/main.py
View file @
5a3b762c
...
@@ -29,13 +29,14 @@ import tensorflow.compat.v2 as tf
...
@@ -29,13 +29,14 @@ import tensorflow.compat.v2 as tf
from
official.modeling.hyperparams
import
params_dict
from
official.modeling.hyperparams
import
params_dict
from
official.modeling.training
import
distributed_executor
as
executor
from
official.modeling.training
import
distributed_executor
as
executor
from
official.utils
import
hyperparams_flags
from
official.vision.detection.configs
import
factory
as
config_factory
from
official.vision.detection.configs
import
factory
as
config_factory
from
official.vision.detection.dataloader
import
input_reader
from
official.vision.detection.dataloader
import
input_reader
from
official.vision.detection.dataloader
import
mode_keys
as
ModeKeys
from
official.vision.detection.dataloader
import
mode_keys
as
ModeKeys
from
official.vision.detection.executor.detection_executor
import
DetectionDistributedExecutor
from
official.vision.detection.executor.detection_executor
import
DetectionDistributedExecutor
from
official.vision.detection.modeling
import
factory
as
model_factory
from
official.vision.detection.modeling
import
factory
as
model_factory
executor
.
initialize_common_flags
()
hyperparams_flags
.
initialize_common_flags
()
flags
.
DEFINE_string
(
flags
.
DEFINE_string
(
'mode'
,
'mode'
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment