ModelZoo / MLPerf_ResNet50_tensorflow · Commits

Commit 05631eec
Authored Apr 10, 2023 by liangjing
Commit message: version 1
Parent: 7e0391d9
Changes: 112

Showing 20 changed files with 974 additions and 0 deletions (+974, -0)
tf2_common/utils/flags/__pycache__/_distribution.cpython-36.pyc  +0 -0
tf2_common/utils/flags/__pycache__/_distribution.cpython-38.pyc  +0 -0
tf2_common/utils/flags/__pycache__/_misc.cpython-36.pyc  +0 -0
tf2_common/utils/flags/__pycache__/_misc.cpython-38.pyc  +0 -0
tf2_common/utils/flags/__pycache__/_performance.cpython-36.pyc  +0 -0
tf2_common/utils/flags/__pycache__/_performance.cpython-38.pyc  +0 -0
tf2_common/utils/flags/__pycache__/core.cpython-36.pyc  +0 -0
tf2_common/utils/flags/__pycache__/core.cpython-38.pyc  +0 -0
tf2_common/utils/flags/_base.py  +163 -0
tf2_common/utils/flags/_benchmark.py  +105 -0
tf2_common/utils/flags/_conventions.py  +54 -0
tf2_common/utils/flags/_device.py  +85 -0
tf2_common/utils/flags/_distribution.py  +54 -0
tf2_common/utils/flags/_misc.py  +50 -0
tf2_common/utils/flags/_performance.py  +331 -0
tf2_common/utils/flags/core.py  +132 -0
tf2_common/utils/logs/__init__.py  +0 -0
tf2_common/utils/logs/__pycache__/__init__.cpython-36.pyc  +0 -0
tf2_common/utils/logs/__pycache__/__init__.cpython-38.pyc  +0 -0
tf2_common/utils/logs/__pycache__/cloud_lib.cpython-36.pyc  +0 -0
tf2_common/utils/flags/__pycache__/_distribution.cpython-36.pyc  0 → 100644
File added

tf2_common/utils/flags/__pycache__/_distribution.cpython-38.pyc  0 → 100644
File added

tf2_common/utils/flags/__pycache__/_misc.cpython-36.pyc  0 → 100644
File added

tf2_common/utils/flags/__pycache__/_misc.cpython-38.pyc  0 → 100644
File added

tf2_common/utils/flags/__pycache__/_performance.cpython-36.pyc  0 → 100644
File added

tf2_common/utils/flags/__pycache__/_performance.cpython-38.pyc  0 → 100644
File added

tf2_common/utils/flags/__pycache__/core.cpython-36.pyc  0 → 100644
File added

tf2_common/utils/flags/__pycache__/core.cpython-38.pyc  0 → 100644
File added
tf2_common/utils/flags/_base.py  0 → 100644

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flags which will be nearly universal across models."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import flags
import tensorflow as tf

from tf2_common.utils.flags._conventions import help_wrap
from tf2_common.utils.logs import hooks_helper


def define_base(data_dir=True, model_dir=True, clean=False, train_epochs=False,
                epochs_between_evals=False, stop_threshold=False,
                batch_size=True, num_gpu=False, hooks=False, export_dir=False,
                distribution_strategy=False, run_eagerly=False):
  """Register base flags.

  Args:
    data_dir: Create a flag for specifying the input data directory.
    model_dir: Create a flag for specifying the model file directory.
    clean: Create a flag for removing the model_dir.
    train_epochs: Create a flag to specify the number of training epochs.
    epochs_between_evals: Create a flag to specify the frequency of testing.
    stop_threshold: Create a flag to specify a threshold accuracy or other
      eval metric which should trigger the end of training.
    batch_size: Create a flag to specify the batch size.
    num_gpu: Create a flag to specify the number of GPUs used.
    hooks: Create a flag to specify hooks for logging.
    export_dir: Create a flag to specify where a SavedModel should be exported.
    distribution_strategy: Create a flag to specify which Distribution Strategy
      to use.
    run_eagerly: Create a flag to specify to run eagerly op by op.

  Returns:
    A list of flags for core.py to marks as key flags.
  """
  key_flags = []

  if data_dir:
    flags.DEFINE_string(
        name="data_dir", short_name="dd", default="/tmp",
        help=help_wrap("The location of the input data."))
    key_flags.append("data_dir")

  if model_dir:
    flags.DEFINE_string(
        name="model_dir", short_name="md", default="/tmp",
        help=help_wrap("The location of the model checkpoint files."))
    key_flags.append("model_dir")

  if clean:
    flags.DEFINE_boolean(
        name="clean", default=False,
        help=help_wrap("If set, model_dir will be removed if it exists."))
    key_flags.append("clean")

  if train_epochs:
    flags.DEFINE_integer(
        name="train_epochs", short_name="te", default=1,
        help=help_wrap("The number of epochs used to train."))
    key_flags.append("train_epochs")

  if epochs_between_evals:
    flags.DEFINE_integer(
        name="epochs_between_evals", short_name="ebe", default=1,
        help=help_wrap("The number of training epochs to run between "
                       "evaluations."))
    key_flags.append("epochs_between_evals")

  if stop_threshold:
    flags.DEFINE_float(
        name="stop_threshold", short_name="st", default=None,
        help=help_wrap("If passed, training will stop at the earlier of "
                       "train_epochs and when the evaluation metric is "
                       "greater than or equal to stop_threshold."))

  if batch_size:
    flags.DEFINE_integer(
        name="batch_size", short_name="bs", default=32,
        help=help_wrap("Batch size for training and evaluation. When using "
                       "multiple gpus, this is the global batch size for "
                       "all devices. For example, if the batch size is 32 "
                       "and there are 4 GPUs, each GPU will get 8 examples on "
                       "each step."))
    key_flags.append("batch_size")

  if num_gpu:
    flags.DEFINE_integer(
        name="num_gpus", short_name="ng", default=1,
        help=help_wrap("How many GPUs to use at each worker with the "
                       "DistributionStrategies API. The default is 1."))

  if run_eagerly:
    flags.DEFINE_boolean(
        name="run_eagerly", default=False,
        help="Run the model op by op without building a model function.")

  if hooks:
    # Construct a pretty summary of hooks.
    hook_list_str = (
        u"\ufeffHook:\n" + u"\n".join([u"\ufeff{}".format(key) for key
                                       in hooks_helper.HOOKS]))
    flags.DEFINE_list(
        name="hooks", short_name="hk", default="LoggingTensorHook",
        help=help_wrap(
            u"A list of (case insensitive) strings to specify the names of "
            u"training hooks.\n{}\n\ufeffExample: `--hooks ProfilerHook,"
            u"ExamplesPerSecondHook`\n See official.utils.logs.hooks_helper "
            u"for details.".format(hook_list_str))
    )
    key_flags.append("hooks")

  if export_dir:
    flags.DEFINE_string(
        name="export_dir", short_name="ed", default=None,
        help=help_wrap("If set, a SavedModel serialization of the model will "
                       "be exported to this directory at the end of training. "
                       "See the README for more details and relevant links.")
    )
    key_flags.append("export_dir")

  if distribution_strategy:
    flags.DEFINE_string(
        name="distribution_strategy", short_name="ds", default="mirrored",
        help=help_wrap("The Distribution Strategy to use for training. "
                       "Accepted values are 'off', 'one_device', "
                       "'mirrored', 'parameter_server', 'collective', "
                       "case insensitive. 'off' means not to use "
                       "Distribution Strategy; 'default' means to choose "
                       "from `MirroredStrategy` or `OneDeviceStrategy` "
                       "according to the number of GPUs.")
    )

  return key_flags


def get_num_gpus(flags_obj):
  """Treat num_gpus=-1 as 'use all'."""
  if flags_obj.num_gpus != -1:
    return flags_obj.num_gpus

  from tensorflow.python.client import device_lib  # pylint: disable=g-import-not-at-top
  local_device_protos = device_lib.list_local_devices()
  return sum([1 for d in local_device_protos if d.device_type == "GPU"])
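For orientation (not part of the commit): a minimal sketch of how the flags defined above are typically consumed by a model entry point. The script layout and flag choices are invented for illustration; it only assumes absl-py is installed and tf2_common is importable.

# Hypothetical driver for the flag definitions above.
from absl import app
from absl import flags

from tf2_common.utils.flags import _base

_base.define_base(train_epochs=True, num_gpu=True)
FLAGS = flags.FLAGS


def main(_):
  # Values are available once absl has parsed sys.argv.
  print("data_dir:", FLAGS.data_dir)
  print("batch_size:", FLAGS.batch_size)
  print("num_gpus resolved:", _base.get_num_gpus(FLAGS))


if __name__ == "__main__":
  app.run(main)

In practice models would go through core.py (added later in this commit) rather than importing _base directly, so that key flags are declared in one place.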
tf2_common/utils/flags/_benchmark.py  0 → 100644

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flags for benchmarking models."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import flags

from tf2_common.utils.flags._conventions import help_wrap


def define_benchmark(benchmark_log_dir=True, bigquery_uploader=True):
  """Register benchmarking flags.

  Args:
    benchmark_log_dir: Create a flag to specify location for benchmark logging.
    bigquery_uploader: Create flags for uploading results to BigQuery.

  Returns:
    A list of flags for core.py to marks as key flags.
  """
  key_flags = []

  flags.DEFINE_enum(
      name="benchmark_logger_type", default="BaseBenchmarkLogger",
      enum_values=["BaseBenchmarkLogger", "BenchmarkFileLogger",
                   "BenchmarkBigQueryLogger"],
      help=help_wrap("The type of benchmark logger to use. Defaults to using "
                     "BaseBenchmarkLogger which logs to STDOUT. Different "
                     "loggers will require other flags to be able to work."))
  flags.DEFINE_string(
      name="benchmark_test_id", short_name="bti", default=None,
      help=help_wrap("The unique test ID of the benchmark run. It could be the "
                     "combination of key parameters. It is hardware "
                     "independent and could be used compare the performance "
                     "between different test runs. This flag is designed for "
                     "human consumption, and does not have any impact within "
                     "the system."))

  flags.DEFINE_integer(
      name='log_steps', default=100,
      help='For every log_steps, we log the timing information such as '
      'examples per second. Besides, for every log_steps, we store the '
      'timestamp of a batch end.')

  if benchmark_log_dir:
    flags.DEFINE_string(
        name="benchmark_log_dir", short_name="bld", default=None,
        help=help_wrap("The location of the benchmark logging.")
    )

  if bigquery_uploader:
    flags.DEFINE_string(
        name="gcp_project", short_name="gp", default=None,
        help=help_wrap(
            "The GCP project name where the benchmark will be uploaded."))

    flags.DEFINE_string(
        name="bigquery_data_set", short_name="bds", default="test_benchmark",
        help=help_wrap(
            "The Bigquery dataset name where the benchmark will be uploaded."))

    flags.DEFINE_string(
        name="bigquery_run_table", short_name="brt", default="benchmark_run",
        help=help_wrap("The Bigquery table name where the benchmark run "
                       "information will be uploaded."))

    flags.DEFINE_string(
        name="bigquery_run_status_table", short_name="brst",
        default="benchmark_run_status",
        help=help_wrap("The Bigquery table name where the benchmark run "
                       "status information will be uploaded."))

    flags.DEFINE_string(
        name="bigquery_metric_table", short_name="bmt",
        default="benchmark_metric",
        help=help_wrap("The Bigquery table name where the benchmark metric "
                       "information will be uploaded."))

  @flags.multi_flags_validator(
      ["benchmark_logger_type", "benchmark_log_dir"],
      message="--benchmark_logger_type=BenchmarkFileLogger will require "
              "--benchmark_log_dir being set")
  def _check_benchmark_log_dir(flags_dict):
    benchmark_logger_type = flags_dict["benchmark_logger_type"]
    if benchmark_logger_type == "BenchmarkFileLogger":
      return flags_dict["benchmark_log_dir"]
    return True

  return key_flags
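A small sketch (not part of the commit) of the effect of the multi-flag validator registered above; the script and paths are hypothetical.

# Hypothetical sketch: define_benchmark() ties BenchmarkFileLogger to
# --benchmark_log_dir via the validator above.
from absl import app
from absl import flags

from tf2_common.utils.flags import _benchmark

_benchmark.define_benchmark()


def main(_):
  print("logger:", flags.FLAGS.benchmark_logger_type)


if __name__ == "__main__":
  # Passing `--benchmark_logger_type=BenchmarkFileLogger` alone fails flag
  # validation; adding `--benchmark_log_dir=/tmp/bench` makes it pass.
  app.run(main)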
tf2_common/utils/flags/_conventions.py  0 → 100644

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Central location for shared argparse convention definitions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys
import codecs
import functools

from absl import app as absl_app
from absl import flags


# This codifies help string conventions and makes it easy to update them if
# necessary. Currently the only major effect is that help bodies start on the
# line after flags are listed. All flag definitions should wrap the text bodies
# with help wrap when calling DEFINE_*.
_help_wrap = functools.partial(flags.text_wrap, length=80, indent="",
                               firstline_indent="\n")


# Pretty formatting causes issues when utf-8 is not installed on a system.
def _stdout_utf8():
  try:
    codecs.lookup("utf-8")
  except LookupError:
    return False
  return sys.stdout.encoding == "UTF-8"


if _stdout_utf8():
  help_wrap = _help_wrap
else:
  def help_wrap(text, *args, **kwargs):
    return _help_wrap(text, *args, **kwargs).replace(u"\ufeff", u"")


# Replace None with h to also allow -h
absl_app.HelpshortFlag.SHORT_NAME = "h"
tf2_common/utils/flags/_device.py  0 → 100644

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flags for managing compute devices. Currently only contains TPU flags."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import flags
import tensorflow as tf

from tf2_common.utils.flags._conventions import help_wrap


def require_cloud_storage(flag_names):
  """Register a validator to check directory flags.

  Args:
    flag_names: An iterable of strings containing the names of flags to be
      checked.
  """
  msg = "TPU requires GCS path for {}".format(", ".join(flag_names))
  @flags.multi_flags_validator(["tpu"] + flag_names, message=msg)
  def _path_check(flag_values):  # pylint: disable=missing-docstring
    if flag_values["tpu"] is None:
      return True

    valid_flags = True
    for key in flag_names:
      if not flag_values[key].startswith("gs://"):
        tf.compat.v1.logging.error("{} must be a GCS path.".format(key))
        valid_flags = False

    return valid_flags


def define_device(tpu=True):
  """Register device specific flags.

  Args:
    tpu: Create flags to specify TPU operation.

  Returns:
    A list of flags for core.py to marks as key flags.
  """
  key_flags = []

  if tpu:
    flags.DEFINE_string(
        name="tpu", default=None,
        help=help_wrap(
            "The Cloud TPU to use for training. This should be either the name "
            "used when creating the Cloud TPU, or a "
            "grpc://ip.address.of.tpu:8470 url. Passing `local` will use the"
            "CPU of the local instance instead. (Good for debugging.)"))
    key_flags.append("tpu")

    flags.DEFINE_string(
        name="tpu_zone", default=None,
        help=help_wrap(
            "[Optional] GCE zone where the Cloud TPU is located in. If not "
            "specified, we will attempt to automatically detect the GCE "
            "project from metadata."))

    flags.DEFINE_string(
        name="tpu_gcp_project", default=None,
        help=help_wrap(
            "[Optional] Project name for the Cloud TPU-enabled project. If not "
            "specified, we will attempt to automatically detect the GCE "
            "project from metadata."))

    flags.DEFINE_integer(name="num_tpu_shards", default=8,
                         help=help_wrap("Number of shards (TPU chips)."))

  return key_flags
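A minimal sketch (not part of the commit) of how require_cloud_storage combines with flags defined elsewhere in this commit; the bucket path and TPU name are made up.

# Hypothetical sketch: require --data_dir to be a GCS path whenever --tpu is
# set. The flag names reuse definitions from _base.py and _device.py above.
from tf2_common.utils.flags import _base
from tf2_common.utils.flags import _device

_base.define_base(data_dir=True)
_device.define_device(tpu=True)
_device.require_cloud_storage(["data_dir"])

# Parsing `--tpu=my-tpu --data_dir=/local/path` now fails validation, while
# `--tpu=my-tpu --data_dir=gs://my-bucket/imagenet` is accepted.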
tf2_common/utils/flags/_distribution.py  0 → 100644

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flags related to distributed execution."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import flags
import tensorflow as tf

from tf2_common.utils.flags._conventions import help_wrap


def define_distribution(worker_hosts=True, task_index=True):
  """Register distributed execution flags.

  Args:
    worker_hosts: Create a flag for specifying comma-separated list of workers.
    task_index: Create a flag for specifying index of task.

  Returns:
    A list of flags for core.py to marks as key flags.
  """
  key_flags = []

  if worker_hosts:
    flags.DEFINE_string(
        name='worker_hosts', default=None,
        help=help_wrap(
            'Comma-separated list of worker ip:port pairs for running '
            'multi-worker models with DistributionStrategy. The user would '
            'start the program on each host with identical value for this '
            'flag.'))

  if task_index:
    flags.DEFINE_integer(
        name='task_index', default=-1,
        help=help_wrap('If multi-worker training, the task_index of this '
                       'worker.'))

  return key_flags
tf2_common/utils/flags/_misc.py  0 → 100644

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Misc flags."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import flags

from tf2_common.utils.flags._conventions import help_wrap


def define_image(data_format=True):
  """Register image specific flags.

  Args:
    data_format: Create a flag to specify image axis convention.

  Returns:
    A list of flags for core.py to marks as key flags.
  """
  key_flags = []

  if data_format:
    flags.DEFINE_enum(
        name="data_format", short_name="df", default=None,
        enum_values=["channels_first", "channels_last"],
        help=help_wrap(
            "A flag to override the data format used in the model. "
            "channels_first provides a performance boost on GPU but is not "
            "always compatible with CPU. If left unspecified, the data format "
            "will be chosen automatically based on whether TensorFlow was "
            "built for CPU or GPU."))
    key_flags.append("data_format")

  return key_flags
tf2_common/utils/flags/_performance.py  0 → 100644

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Register flags for optimizing performance."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import multiprocessing

from absl import flags  # pylint: disable=g-bad-import-order
import tensorflow as tf  # pylint: disable=g-bad-import-order

from tf2_common.utils.flags._conventions import help_wrap


# Map string to TensorFlow dtype
DTYPE_MAP = {
    "fp16": tf.float16,
    "bf16": tf.bfloat16,
    "fp32": tf.float32,
}


def get_tf_dtype(flags_obj):
  if getattr(flags_obj, "fp16_implementation", None) == "graph_rewrite":
    # If the graph_rewrite is used, we build the graph with fp32, and let the
    # graph rewrite change ops to fp16.
    return tf.float32
  return DTYPE_MAP[flags_obj.dtype]


def get_loss_scale(flags_obj, default_for_fp16):
  if flags_obj.loss_scale == "dynamic":
    return flags_obj.loss_scale
  elif flags_obj.loss_scale is not None:
    return float(flags_obj.loss_scale)
  elif flags_obj.dtype == "fp32" or flags_obj.dtype == "bf16":
    return 1  # No loss scaling is needed for fp32 and bf16
  else:
    assert flags_obj.dtype == "fp16"
    return default_for_fp16


def define_performance(num_parallel_calls=False, inter_op=False,
                       intra_op=False, synthetic_data=False,
                       max_train_steps=False, dtype=False,
                       all_reduce_alg=False, num_packs=False,
                       tf_gpu_thread_mode=False,
                       datasets_num_private_threads=False,
                       datasets_num_parallel_batches=False,
                       dynamic_loss_scale=False, fp16_implementation=False,
                       loss_scale=False, tf_data_experimental_slack=False,
                       enable_xla=False, force_v2_in_keras_compile=False,
                       training_dataset_cache=False,
                       training_prefetch_batchs=False,
                       eval_dataset_cache=False, eval_prefetch_batchs=False):
  """Register flags for specifying performance tuning arguments.

  Args:
    num_parallel_calls: Create a flag to specify parallelism of data loading.
    inter_op: Create a flag to allow specification of inter op threads.
    intra_op: Create a flag to allow specification of intra op threads.
    synthetic_data: Create a flag to allow the use of synthetic data.
    max_train_steps: Create a flags to allow specification of maximum number
      of training steps
    dtype: Create flags for specifying dtype.
    all_reduce_alg: If set forces a specific algorithm for multi-gpu.
    num_packs: If set provides number of packs for MirroredStrategy's cross
      device ops.
    tf_gpu_thread_mode: gpu_private triggers us of private thread pool.
    datasets_num_private_threads: Number of private threads for datasets.
    datasets_num_parallel_batches: Determines how many batches to process in
      parallel when using map and batch from tf.data.
    dynamic_loss_scale: Allow the "loss_scale" flag to take on the value
      "dynamic". Only valid if `dtype` is True.
    fp16_implementation: Create fp16_implementation flag.
    loss_scale: Controls the loss scaling, normally for mixed-precision
      training. Can only be turned on if dtype is also True.
    tf_data_experimental_slack: Determines whether to enable tf.data's
      `experimental_slack` option.
    enable_xla: Determines if XLA (auto clustering) is turned on.
    force_v2_in_keras_compile: Forces the use of run_distribued path even if not
      using a `strategy`. This is not the same as
      `tf.distribute.OneDeviceStrategy`
    training_dataset_cache: Whether to cache the training dataset on workers.
      Typically used to improve training performance when training data is in
      remote storage and can fit into worker memory.
    training_prefetch_bachs: The number of batchs to prefetch for training.
    eval_dataset_cache: Whether to cache the eval dataset on workers.
    eval_prefetch_bachs: The number of batchs to prefetch for eval.

  Returns:
    A list of flags for core.py to marks as key flags.
  """
  key_flags = []
  if num_parallel_calls:
    flags.DEFINE_integer(
        name="num_parallel_calls", short_name="npc",
        default=multiprocessing.cpu_count(),
        help=help_wrap("The number of records that are processed in parallel "
                       "during input processing. This can be optimized per "
                       "data set but for generally homogeneous data sets, "
                       "should be approximately the number of available CPU "
                       "cores. (default behavior)"))

  if inter_op:
    flags.DEFINE_integer(
        name="inter_op_parallelism_threads", short_name="inter", default=0,
        help=help_wrap("Number of inter_op_parallelism_threads to use for CPU. "
                       "See TensorFlow config.proto for details.")
    )

  if intra_op:
    flags.DEFINE_integer(
        name="intra_op_parallelism_threads", short_name="intra", default=0,
        help=help_wrap("Number of intra_op_parallelism_threads to use for CPU. "
                       "See TensorFlow config.proto for details."))

  if synthetic_data:
    flags.DEFINE_bool(
        name="use_synthetic_data", short_name="synth", default=False,
        help=help_wrap(
            "If set, use fake data (zeroes) instead of a real dataset. "
            "This mode is useful for performance debugging, as it removes "
            "input processing steps, but will not learn anything."))

  if max_train_steps:
    flags.DEFINE_integer(
        name="max_train_steps", short_name="mts", default=None,
        help=help_wrap(
            "The model will stop training if the global_step reaches this "
            "value. If not set, training will run until the specified number "
            "of epochs have run as usual. It is generally recommended to set "
            "--train_epochs=1 when using this flag."))

  if dtype:
    flags.DEFINE_enum(
        name="dtype", short_name="dt", default="fp32",
        enum_values=DTYPE_MAP.keys(),
        help=help_wrap("The TensorFlow datatype used for calculations. "
                       "Variables may be cast to a higher precision on a "
                       "case-by-case basis for numerical stability."))

    loss_scale_help_text = (
        "The amount to scale the loss by when the model is run. {}. Before "
        "gradients are computed, the loss is multiplied by the loss scale, "
        "making all gradients loss_scale times larger. To adjust for this, "
        "gradients are divided by the loss scale before being applied to "
        "variables. This is mathematically equivalent to training without "
        "a loss scale, but the loss scale helps avoid some intermediate "
        "gradients from underflowing to zero. If not provided the default "
        "for fp16 is 128 and 1 for all other dtypes.{}")
    if dynamic_loss_scale:
      loss_scale_help_text = loss_scale_help_text.format(
          "This can be an int/float or the string 'dynamic'",
          " The string 'dynamic' can be used to dynamically determine the "
          "optimal loss scale during training, but currently this "
          "significantly slows down performance")
      loss_scale_validation_msg = ("loss_scale should be a positive int/float "
                                   "or the string 'dynamic'.")
    else:
      loss_scale_help_text = loss_scale_help_text.format(
          "This must be an int/float", "")
      loss_scale_validation_msg = "loss_scale should be a positive int/float."

    if loss_scale:
      flags.DEFINE_string(
          name="loss_scale", short_name="ls", default=None,
          help=help_wrap(loss_scale_help_text))

      @flags.validator(flag_name="loss_scale",
                       message=loss_scale_validation_msg)
      def _check_loss_scale(loss_scale):  # pylint: disable=unused-variable
        """Validator to check the loss scale flag is valid."""
        if loss_scale is None:
          return True  # null case is handled in get_loss_scale()

        if loss_scale == "dynamic" and dynamic_loss_scale:
          return True

        try:
          loss_scale = float(loss_scale)
        except ValueError:
          return False

        return loss_scale > 0

    if fp16_implementation:
      flags.DEFINE_enum(
          name="fp16_implementation", default="keras",
          enum_values=('keras', 'graph_rewrite'),
          help=help_wrap(
              "When --dtype=fp16, how fp16 should be implemented. This has no "
              "impact on correctness. 'keras' uses the "
              "tf.keras.mixed_precision API. 'graph_rewrite' uses the "
              "tf.train.experimental.enable_mixed_precision_graph_rewrite "
              "API."))

      @flags.multi_flags_validator(["fp16_implementation", "dtype",
                                    "loss_scale"])
      def _check_fp16_implementation(flags_dict):
        """Validator to check fp16_implementation flag is valid."""
        if (flags_dict["fp16_implementation"] == "graph_rewrite" and
            flags_dict["dtype"] != "fp16"):
          raise flags.ValidationError("--fp16_implementation should not be "
                                      "specified unless --dtype=fp16")
        return True

  if all_reduce_alg:
    flags.DEFINE_string(
        name="all_reduce_alg", short_name="ara", default=None,
        help=help_wrap("Defines the algorithm to use for performing all-reduce."
                       "When specified with MirroredStrategy for single "
                       "worker, this controls "
                       "tf.contrib.distribute.AllReduceCrossTowerOps. When "
                       "specified with MultiWorkerMirroredStrategy, this "
                       "controls "
                       "tf.distribute.experimental.CollectiveCommunication; "
                       "valid options are `ring` and `nccl`."))

  if num_packs:
    flags.DEFINE_integer(
        name="num_packs", default=1,
        help=help_wrap("Sets `num_packs` in the cross device ops used in "
                       "MirroredStrategy. For details, see "
                       "tf.distribute.NcclAllReduce."))

  if tf_gpu_thread_mode:
    flags.DEFINE_string(
        name="tf_gpu_thread_mode", short_name="gt_mode", default=None,
        help=help_wrap(
            "Whether and how the GPU device uses its own threadpool.")
    )

    flags.DEFINE_integer(
        name="per_gpu_thread_count", short_name="pgtc", default=0,
        help=help_wrap(
            "The number of threads to use for GPU. Only valid when "
            "tf_gpu_thread_mode is not global.")
    )

  if datasets_num_private_threads:
    flags.DEFINE_integer(
        name="datasets_num_private_threads",
        default=None,
        help=help_wrap(
            "Number of threads for a private threadpool created for all"
            "datasets computation..")
    )

  if datasets_num_parallel_batches:
    flags.DEFINE_integer(
        name="datasets_num_parallel_batches",
        default=None,
        help=help_wrap(
            "Determines how many batches to process in parallel when using "
            "map and batch from tf.data.")
    )

  if training_dataset_cache:
    flags.DEFINE_boolean(
        name="training_dataset_cache",
        default=False,
        help=help_wrap(
            "Determines whether to cache the training dataset on workers. "
            "Typically used to improve training performance when training "
            "data is in remote storage and can fit into worker memory.")
    )

  if training_prefetch_batchs:
    flags.DEFINE_integer(
        name="training_prefetch_batchs",
        default=tf.data.experimental.AUTOTUNE,
        help=help_wrap(
            "The number of batchs to prefetch for the training dataset.")
    )

  if eval_dataset_cache:
    flags.DEFINE_boolean(
        name="eval_dataset_cache",
        default=False,
        help=help_wrap(
            "Determines whether to cache the eval dataset on workers. "
            "Typically used to improve eval performance when eval "
            "data is in remote storage and can fit into worker memory.")
    )

  if eval_prefetch_batchs:
    flags.DEFINE_integer(
        name="eval_prefetch_batchs",
        default=tf.data.experimental.AUTOTUNE,
        help=help_wrap(
            "The number of batchs to prefetch for the eval dataset.")
    )

  if tf_data_experimental_slack:
    flags.DEFINE_boolean(
        name="tf_data_experimental_slack",
        default=False,
        help=help_wrap(
            "Whether to enable tf.data's `experimental_slack` option.")
    )

  if enable_xla:
    flags.DEFINE_boolean(
        name="enable_xla", default=False,
        help="Whether to enable XLA auto jit compilation")

  if force_v2_in_keras_compile:
    flags.DEFINE_boolean(
        name="force_v2_in_keras_compile", default=None,
        help="Forces the use of run_distribued path even if not"
             "using a `strategy`. This is not the same as"
             "`tf.distribute.OneDeviceStrategy`")

  return key_flags
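A small sketch (not part of the commit) showing how get_tf_dtype and get_loss_scale read the flags defined above; the argv list and default_for_fp16 value are made up, and it assumes TensorFlow and absl-py are installed.

# Hypothetical sketch of the dtype / loss-scale helpers defined above.
from absl import flags

from tf2_common.utils.flags import _performance

_performance.define_performance(dtype=True, loss_scale=True,
                                dynamic_loss_scale=True,
                                fp16_implementation=True)
FLAGS = flags.FLAGS
FLAGS(["prog", "--dtype=fp16", "--loss_scale=dynamic"])

print(_performance.get_tf_dtype(FLAGS))                          # tf.float16
print(_performance.get_loss_scale(FLAGS, default_for_fp16=128))  # 'dynamic'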
tf2_common/utils/flags/core.py  0 → 100644

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Public interface for flag definition.

See _example.py for detailed instructions on defining flags.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys
from six.moves import shlex_quote

from absl import app as absl_app
from absl import flags

from tf2_common.utils.flags import _base
from tf2_common.utils.flags import _benchmark
from tf2_common.utils.flags import _conventions
from tf2_common.utils.flags import _device
from tf2_common.utils.flags import _distribution
from tf2_common.utils.flags import _misc
from tf2_common.utils.flags import _performance


def set_defaults(**kwargs):
  for key, value in kwargs.items():
    flags.FLAGS.set_default(name=key, value=value)


def parse_flags(argv=None):
  """Reset flags and reparse. Currently only used in testing."""
  flags.FLAGS.unparse_flags()
  absl_app.parse_flags_with_usage(argv or sys.argv)


def register_key_flags_in_core(f):
  """Defines a function in core.py, and registers its key flags.

  absl uses the location of a flags.declare_key_flag() to determine the context
  in which a flag is key. By making all declares in core, this allows model
  main functions to call flags.adopt_module_key_flags() on core and correctly
  chain key flags.

  Args:
    f: The function to be wrapped

  Returns:
    The "core-defined" version of the input function.
  """
  def core_fn(*args, **kwargs):
    key_flags = f(*args, **kwargs)
    [flags.declare_key_flag(fl) for fl in key_flags]  # pylint: disable=expression-not-assigned
  return core_fn


define_base = register_key_flags_in_core(_base.define_base)
# We have define_base_eager for compatibility, since it used to be a separate
# function from define_base.
define_base_eager = define_base
define_benchmark = register_key_flags_in_core(_benchmark.define_benchmark)
define_device = register_key_flags_in_core(_device.define_device)
define_image = register_key_flags_in_core(_misc.define_image)
define_performance = register_key_flags_in_core(_performance.define_performance)
define_distribution = register_key_flags_in_core(
    _distribution.define_distribution)


help_wrap = _conventions.help_wrap


get_num_gpus = _base.get_num_gpus
get_tf_dtype = _performance.get_tf_dtype
get_loss_scale = _performance.get_loss_scale


DTYPE_MAP = _performance.DTYPE_MAP


require_cloud_storage = _device.require_cloud_storage


def _get_nondefault_flags_as_dict():
  """Returns the nondefault flags as a dict from flag name to value."""
  nondefault_flags = {}
  for flag_name in flags.FLAGS:
    flag_value = getattr(flags.FLAGS, flag_name)
    if (flag_name != flags.FLAGS[flag_name].short_name and
        flag_value != flags.FLAGS[flag_name].default):
      nondefault_flags[flag_name] = flag_value
  return nondefault_flags


def get_nondefault_flags_as_str():
  """Returns flags as a string that can be passed as command line arguments.

  E.g., returns: "--batch_size=256 --use_synthetic_data" for the following code
  block:

  ```
  flags.FLAGS.batch_size = 256
  flags.FLAGS.use_synthetic_data = True
  print(get_nondefault_flags_as_str())
  ```

  Only flags with nondefault values are returned, as passing default flags as
  command line arguments has no effect.

  Returns:
    A string with the flags, that can be passed as command line arguments to a
    program to use the flags.
  """
  nondefault_flags = _get_nondefault_flags_as_dict()
  flag_strings = []
  for name, value in sorted(nondefault_flags.items()):
    if isinstance(value, bool):
      flag_str = '--{}'.format(name) if value else '--no{}'.format(name)
    elif isinstance(value, list):
      flag_str = '--{}={}'.format(name, ','.join(value))
    else:
      flag_str = '--{}={}'.format(name, value)
    flag_strings.append(flag_str)
  return ' '.join(shlex_quote(flag_str) for flag_str in flag_strings)
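To close the loop, a minimal sketch (not part of the commit) of how a model entry point would typically use core.py; the default paths and the particular flag choices below are invented for illustration.

# Hypothetical model entry point built on top of core.py.
from absl import app
from absl import flags

from tf2_common.utils.flags import core as flags_core


def define_my_model_flags():
  flags_core.define_base(train_epochs=True, num_gpu=True)
  flags_core.define_performance(dtype=True)
  flags_core.set_defaults(data_dir="/data/imagenet", batch_size=256)


def main(_):
  flags_obj = flags.FLAGS
  print("dtype:", flags_core.get_tf_dtype(flags_obj))
  print("rerun with:", flags_core.get_nondefault_flags_as_str())


if __name__ == "__main__":
  define_my_model_flags()
  app.run(main)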
tf2_common/utils/logs/__init__.py  0 → 100644

tf2_common/utils/logs/__pycache__/__init__.cpython-36.pyc  0 → 100644
File added

tf2_common/utils/logs/__pycache__/__init__.cpython-38.pyc  0 → 100644
File added

tf2_common/utils/logs/__pycache__/cloud_lib.cpython-36.pyc  0 → 100644
File added