ModelZoo / ResNet50_tensorflow / Commits

Commit 3d61d6b3, authored Mar 30, 2023 by qianyj
initial files for ResNet50
Parent: d3a70caf
Showing 20 changed files with 1474 additions and 0 deletions (+1474, -0)
official/utils/flags/_conventions.py (+50, -0)
official/utils/flags/_device.py (+90, -0)
official/utils/flags/_distribution.py (+52, -0)
official/utils/flags/_misc.py (+48, -0)
official/utils/flags/_performance.py (+294, -0)
official/utils/flags/core.py (+130, -0)
official/utils/flags/flags_test.py (+162, -0)
official/utils/flags/guidelines.md (+65, -0)
official/utils/hyperparams_flags.py (+123, -0)
official/utils/misc/__init__.py (+14, -0)
official/utils/misc/__pycache__/__init__.cpython-37.pyc (+0, -0)
official/utils/misc/__pycache__/keras_utils.cpython-37.pyc (+0, -0)
official/utils/misc/__pycache__/model_helpers.cpython-37.pyc (+0, -0)
official/utils/misc/keras_utils.py (+211, -0)
official/utils/misc/model_helpers.py (+94, -0)
official/utils/misc/model_helpers_test.py (+127, -0)
official/vision/image_classification/resnet/__init__.py (+14, -0)
official/vision/image_classification/resnet/__pycache__/__init__.cpython-37.pyc (+0, -0)
official/vision/image_classification/resnet/__pycache__/common.cpython-37.pyc (+0, -0)
official/vision/image_classification/resnet/__pycache__/imagenet_preprocessing.cpython-37.pyc (+0, -0)
official/utils/flags/_conventions.py
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Central location for shared argparse convention definitions."""
import sys
import codecs
import functools

from absl import app as absl_app
from absl import flags


# This codifies help string conventions and makes it easy to update them if
# necessary. Currently the only major effect is that help bodies start on the
# line after flags are listed. All flag definitions should wrap the text bodies
# with help wrap when calling DEFINE_*.
_help_wrap = functools.partial(
    flags.text_wrap, length=80, indent="", firstline_indent="\n")


# Pretty formatting causes issues when utf-8 is not installed on a system.
def _stdout_utf8():
  try:
    codecs.lookup("utf-8")
  except LookupError:
    return False
  return getattr(sys.stdout, "encoding", "") == "UTF-8"


if _stdout_utf8():
  help_wrap = _help_wrap
else:

  def help_wrap(text, *args, **kwargs):
    return _help_wrap(text, *args, **kwargs).replace(u"\ufeff", u"")


# Replace None with h to also allow -h
absl_app.HelpshortFlag.SHORT_NAME = "h"
official/utils/flags/_device.py
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Flags for managing compute devices. Currently only contains TPU flags."""
from absl import flags
from absl import logging

from official.utils.flags._conventions import help_wrap


def require_cloud_storage(flag_names):
  """Register a validator to check directory flags.

  Args:
    flag_names: An iterable of strings containing the names of flags to be
      checked.
  """
  msg = "TPU requires GCS path for {}".format(", ".join(flag_names))

  @flags.multi_flags_validator(["tpu"] + flag_names, message=msg)
  def _path_check(flag_values):  # pylint: disable=missing-docstring
    if flag_values["tpu"] is None:
      return True

    valid_flags = True
    for key in flag_names:
      if not flag_values[key].startswith("gs://"):
        logging.error("%s must be a GCS path.", key)
        valid_flags = False

    return valid_flags


def define_device(tpu=True):
  """Register device specific flags.

  Args:
    tpu: Create flags to specify TPU operation.

  Returns:
    A list of flags for core.py to mark as key flags.
  """

  key_flags = []

  if tpu:
    flags.DEFINE_string(
        name="tpu",
        default=None,
        help=help_wrap(
            "The Cloud TPU to use for training. This should be either the name "
            "used when creating the Cloud TPU, or a "
            "grpc://ip.address.of.tpu:8470 url. Passing `local` will use the "
            "CPU of the local instance instead. (Good for debugging.)"))
    key_flags.append("tpu")

    flags.DEFINE_string(
        name="tpu_zone",
        default=None,
        help=help_wrap(
            "[Optional] GCE zone where the Cloud TPU is located in. If not "
            "specified, we will attempt to automatically detect the GCE "
            "project from metadata."))

    flags.DEFINE_string(
        name="tpu_gcp_project",
        default=None,
        help=help_wrap(
            "[Optional] Project name for the Cloud TPU-enabled project. If not "
            "specified, we will attempt to automatically detect the GCE "
            "project from metadata."))

    flags.DEFINE_integer(
        name="num_tpu_shards",
        default=8,
        help=help_wrap("Number of shards (TPU chips)."))

  return key_flags
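For context, a minimal usage sketch (not part of this commit) of how the TPU flags and the GCS validator above might be wired into a trainer. It assumes `data_dir` and `model_dir` flags are defined elsewhere (e.g. by `define_base`), and `main` is purely illustrative:

```python
from absl import app as absl_app
from absl import flags

from official.utils.flags import _device

# Register the TPU flags, then require that the (assumed) data_dir/model_dir
# flags point to GCS whenever --tpu is set.
_device.define_device(tpu=True)
_device.require_cloud_storage(["data_dir", "model_dir"])


def main(_):
  if flags.FLAGS.tpu:
    print("Training on TPU:", flags.FLAGS.tpu)


if __name__ == "__main__":
  absl_app.run(main)
```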
official/utils/flags/_distribution.py
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Flags related to distributed execution."""
from absl import flags
import tensorflow as tf

from official.utils.flags._conventions import help_wrap


def define_distribution(worker_hosts=True, task_index=True):
  """Register distributed execution flags.

  Args:
    worker_hosts: Create a flag for specifying comma-separated list of workers.
    task_index: Create a flag for specifying index of task.

  Returns:
    A list of flags for core.py to mark as key flags.
  """
  key_flags = []

  if worker_hosts:
    flags.DEFINE_string(
        name='worker_hosts',
        default=None,
        help=help_wrap(
            'Comma-separated list of worker ip:port pairs for running '
            'multi-worker models with DistributionStrategy. The user would '
            'start the program on each host with identical value for this '
            'flag.'))

  if task_index:
    flags.DEFINE_integer(
        name='task_index',
        default=-1,
        help=help_wrap('If multi-worker training, the task_index of this '
                       'worker.'))

  return key_flags
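A hedged sketch of how a trainer might consume these two flags to pick a distribution strategy; the TF_CONFIG construction below is an assumption for illustration, not something this commit prescribes:

```python
import json
import os

import tensorflow as tf


def make_strategy(flags_obj):
  """Builds a strategy from --worker_hosts/--task_index (illustrative only)."""
  if flags_obj.worker_hosts:
    workers = flags_obj.worker_hosts.split(",")
    # MultiWorkerMirroredStrategy reads the cluster layout from TF_CONFIG.
    os.environ["TF_CONFIG"] = json.dumps({
        "cluster": {"worker": workers},
        "task": {"type": "worker", "index": flags_obj.task_index},
    })
    return tf.distribute.experimental.MultiWorkerMirroredStrategy()
  return tf.distribute.MirroredStrategy()
```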
official/utils/flags/_misc.py
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Misc flags."""
from absl import flags

from official.utils.flags._conventions import help_wrap


def define_image(data_format=True):
  """Register image specific flags.

  Args:
    data_format: Create a flag to specify image axis convention.

  Returns:
    A list of flags for core.py to mark as key flags.
  """

  key_flags = []

  if data_format:
    flags.DEFINE_enum(
        name="data_format",
        short_name="df",
        default=None,
        enum_values=["channels_first", "channels_last"],
        help=help_wrap(
            "A flag to override the data format used in the model. "
            "channels_first provides a performance boost on GPU but is not "
            "always compatible with CPU. If left unspecified, the data format "
            "will be chosen automatically based on whether TensorFlow was "
            "built for CPU or GPU."))
    key_flags.append("data_format")

  return key_flags
official/utils/flags/_performance.py
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Register flags for optimizing performance."""
import multiprocessing

from absl import flags
# pylint: disable=g-bad-import-order
import tensorflow as tf
# pylint: disable=g-bad-import-order

from official.utils.flags._conventions import help_wrap


# Map string to TensorFlow dtype
DTYPE_MAP = {
    "fp16": tf.float16,
    "bf16": tf.bfloat16,
    "fp32": tf.float32,
}


def get_tf_dtype(flags_obj):
  if getattr(flags_obj, "fp16_implementation", None) == "graph_rewrite":
    # If the graph_rewrite is used, we build the graph with fp32, and let the
    # graph rewrite change ops to fp16.
    return tf.float32
  return DTYPE_MAP[flags_obj.dtype]


def get_loss_scale(flags_obj, default_for_fp16):
  dtype = get_tf_dtype(flags_obj)
  if flags_obj.loss_scale == "dynamic":
    return flags_obj.loss_scale
  elif flags_obj.loss_scale is not None:
    return float(flags_obj.loss_scale)
  elif dtype == tf.float32 or dtype == tf.bfloat16:
    return 1  # No loss scaling is needed for fp32
  else:
    assert dtype == tf.float16
    return default_for_fp16


def define_performance(num_parallel_calls=False,
                       inter_op=False,
                       intra_op=False,
                       synthetic_data=False,
                       max_train_steps=False,
                       dtype=False,
                       all_reduce_alg=False,
                       num_packs=False,
                       tf_gpu_thread_mode=False,
                       datasets_num_private_threads=False,
                       datasets_num_parallel_batches=False,
                       fp16_implementation=False,
                       loss_scale=False,
                       tf_data_experimental_slack=False,
                       enable_xla=False,
                       training_dataset_cache=False):
  """Register flags for specifying performance tuning arguments.

  Args:
    num_parallel_calls: Create a flag to specify parallelism of data loading.
    inter_op: Create a flag to allow specification of inter op threads.
    intra_op: Create a flag to allow specification of intra op threads.
    synthetic_data: Create a flag to allow the use of synthetic data.
    max_train_steps: Create a flag to allow specification of maximum number of
      training steps.
    dtype: Create flags for specifying dtype.
    all_reduce_alg: If set forces a specific algorithm for multi-gpu.
    num_packs: If set provides number of packs for MirroredStrategy's cross
      device ops.
    tf_gpu_thread_mode: gpu_private triggers use of a private thread pool.
    datasets_num_private_threads: Number of private threads for datasets.
    datasets_num_parallel_batches: Determines how many batches to process in
      parallel when using map and batch from tf.data.
    fp16_implementation: Create fp16_implementation flag.
    loss_scale: Controls the loss scaling, normally for mixed-precision
      training. Can only be turned on if dtype is also True.
    tf_data_experimental_slack: Determines whether to enable tf.data's
      `experimental_slack` option.
    enable_xla: Determines if XLA (auto clustering) is turned on.
    training_dataset_cache: Whether to cache the training dataset on workers.
      Typically used to improve training performance when training data is in
      remote storage and can fit into worker memory.

  Returns:
    A list of flags for core.py to mark as key flags.
  """

  key_flags = []
  if num_parallel_calls:
    flags.DEFINE_integer(
        name="num_parallel_calls",
        short_name="npc",
        default=multiprocessing.cpu_count(),
        help=help_wrap("The number of records that are processed in parallel "
                       "during input processing. This can be optimized per "
                       "data set but for generally homogeneous data sets, "
                       "should be approximately the number of available CPU "
                       "cores. (default behavior)"))

  if inter_op:
    flags.DEFINE_integer(
        name="inter_op_parallelism_threads",
        short_name="inter",
        default=0,
        help=help_wrap("Number of inter_op_parallelism_threads to use for CPU. "
                       "See TensorFlow config.proto for details."))

  if intra_op:
    flags.DEFINE_integer(
        name="intra_op_parallelism_threads",
        short_name="intra",
        default=0,
        help=help_wrap("Number of intra_op_parallelism_threads to use for CPU. "
                       "See TensorFlow config.proto for details."))

  if synthetic_data:
    flags.DEFINE_bool(
        name="use_synthetic_data",
        short_name="synth",
        default=False,
        help=help_wrap(
            "If set, use fake data (zeroes) instead of a real dataset. "
            "This mode is useful for performance debugging, as it removes "
            "input processing steps, but will not learn anything."))

  if max_train_steps:
    flags.DEFINE_integer(
        name="max_train_steps",
        short_name="mts",
        default=None,
        help=help_wrap(
            "The model will stop training if the global_step reaches this "
            "value. If not set, training will run until the specified number "
            "of epochs have run as usual. It is generally recommended to set "
            "--train_epochs=1 when using this flag."))

  if dtype:
    flags.DEFINE_enum(
        name="dtype",
        short_name="dt",
        default="fp32",
        enum_values=DTYPE_MAP.keys(),
        help=help_wrap("The TensorFlow datatype used for calculations. "
                       "For 16-bit dtypes, variables and certain ops will "
                       "still be float32 for numeric stability."))

    if loss_scale:
      flags.DEFINE_string(
          name="loss_scale",
          short_name="ls",
          default=None,
          help=help_wrap(
              "The amount to scale the loss by when --dtype=fp16. This can be "
              "an int/float or the string 'dynamic'. Before gradients are "
              "computed, the loss is multiplied by the loss scale, making all "
              "gradients loss_scale times larger. To adjust for this, "
              "gradients are divided by the loss scale before being applied to "
              "variables. This is mathematically equivalent to training "
              "without a loss scale, but the loss scale helps avoid some "
              "intermediate gradients from underflowing to zero. The default "
              "is 'dynamic', which dynamically determines the optimal loss "
              "scale during training."))

      # pylint: disable=unused-variable
      @flags.validator(
          flag_name="loss_scale",
          message="loss_scale should be a positive int/float or the string "
                  "'dynamic'.")
      def _check_loss_scale(loss_scale):
        """Validator to check the loss scale flag is valid."""
        if loss_scale is None:
          return True  # null case is handled in get_loss_scale()

        if loss_scale == "dynamic":
          return True

        try:
          loss_scale = float(loss_scale)
        except ValueError:
          return False

        return loss_scale > 0
      # pylint: enable=unused-variable

    if fp16_implementation:
      flags.DEFINE_enum(
          name="fp16_implementation",
          default="keras",
          enum_values=("keras", "graph_rewrite"),
          help=help_wrap(
              "When --dtype=fp16, how fp16 should be implemented. This has no "
              "impact on correctness. 'keras' uses the "
              "tf.keras.mixed_precision API. 'graph_rewrite' uses the "
              "tf.compat.v1.mixed_precision."
              "enable_mixed_precision_graph_rewrite API."))

      @flags.multi_flags_validator(["fp16_implementation", "dtype",
                                    "loss_scale"])
      def _check_fp16_implementation(flags_dict):
        """Validator to check fp16_implementation flag is valid."""
        if (flags_dict["fp16_implementation"] == "graph_rewrite" and
            flags_dict["dtype"] != "fp16"):
          raise flags.ValidationError("--fp16_implementation should not be "
                                      "specified unless --dtype=fp16")
        return True

  if all_reduce_alg:
    flags.DEFINE_string(
        name="all_reduce_alg",
        short_name="ara",
        default=None,
        help=help_wrap(
            "Defines the algorithm to use for performing all-reduce. "
            "When specified with MirroredStrategy for single "
            "worker, this controls "
            "tf.contrib.distribute.AllReduceCrossTowerOps. When "
            "specified with MultiWorkerMirroredStrategy, this "
            "controls "
            "tf.distribute.experimental.CollectiveCommunication; "
            "valid options are `ring` and `nccl`."))

  if num_packs:
    flags.DEFINE_integer(
        name="num_packs",
        default=1,
        help=help_wrap("Sets `num_packs` in the cross device ops used in "
                       "MirroredStrategy. For details, see "
                       "tf.distribute.NcclAllReduce."))

  if tf_gpu_thread_mode:
    flags.DEFINE_string(
        name="tf_gpu_thread_mode",
        short_name="gt_mode",
        default=None,
        help=help_wrap(
            "Whether and how the GPU device uses its own threadpool."))

    flags.DEFINE_integer(
        name="per_gpu_thread_count",
        short_name="pgtc",
        default=0,
        help=help_wrap("The number of threads to use for GPU. Only valid when "
                       "tf_gpu_thread_mode is not global."))

  if datasets_num_private_threads:
    flags.DEFINE_integer(
        name="datasets_num_private_threads",
        default=None,
        help=help_wrap("Number of threads for a private threadpool created "
                       "for all datasets computation."))

  if datasets_num_parallel_batches:
    flags.DEFINE_integer(
        name="datasets_num_parallel_batches",
        default=None,
        help=help_wrap(
            "Determines how many batches to process in parallel when using "
            "map and batch from tf.data."))

  if training_dataset_cache:
    flags.DEFINE_boolean(
        name="training_dataset_cache",
        default=False,
        help=help_wrap(
            "Determines whether to cache the training dataset on workers. "
            "Typically used to improve training performance when training "
            "data is in remote storage and can fit into worker memory."))

  if tf_data_experimental_slack:
    flags.DEFINE_boolean(
        name="tf_data_experimental_slack",
        default=False,
        help=help_wrap(
            "Whether to enable tf.data's `experimental_slack` option."))

  if enable_xla:
    flags.DEFINE_boolean(
        name="enable_xla",
        default=False,
        help="Whether to enable XLA auto jit compilation")

  return key_flags
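A short, hypothetical example of how a model main might register a subset of these flags and read them back through the helpers above; the command line shown in the comment is illustrative only:

```python
from absl import app as absl_app
from absl import flags

from official.utils.flags import _performance

_performance.define_performance(dtype=True, loss_scale=True,
                                fp16_implementation=True, enable_xla=True)


def main(_):
  flags_obj = flags.FLAGS
  dtype = _performance.get_tf_dtype(flags_obj)      # e.g. tf.float16
  loss_scale = _performance.get_loss_scale(flags_obj,
                                           default_for_fp16="dynamic")
  print(dtype, loss_scale)


if __name__ == "__main__":
  # e.g. python train.py --dtype=fp16 --loss_scale=dynamic
  absl_app.run(main)
```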
official/utils/flags/core.py
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Public interface for flag definition.
See _example.py for detailed instructions on defining flags.
"""
import sys

from six.moves import shlex_quote

from absl import app as absl_app
from absl import flags

from official.utils.flags import _base
from official.utils.flags import _benchmark
from official.utils.flags import _conventions
from official.utils.flags import _device
from official.utils.flags import _distribution
from official.utils.flags import _misc
from official.utils.flags import _performance


def set_defaults(**kwargs):
  for key, value in kwargs.items():
    flags.FLAGS.set_default(name=key, value=value)


def parse_flags(argv=None):
  """Reset flags and reparse. Currently only used in testing."""
  flags.FLAGS.unparse_flags()
  absl_app.parse_flags_with_usage(argv or sys.argv)


def register_key_flags_in_core(f):
  """Defines a function in core.py, and registers its key flags.

  absl uses the location of a flags.declare_key_flag() to determine the context
  in which a flag is key. By making all declares in core, this allows model
  main functions to call flags.adopt_module_key_flags() on core and correctly
  chain key flags.

  Args:
    f: The function to be wrapped

  Returns:
    The "core-defined" version of the input function.
  """

  def core_fn(*args, **kwargs):
    key_flags = f(*args, **kwargs)
    [flags.declare_key_flag(fl) for fl in key_flags]  # pylint: disable=expression-not-assigned

  return core_fn


define_base = register_key_flags_in_core(_base.define_base)
# We have define_base_eager for compatibility, since it used to be a separate
# function from define_base.
define_base_eager = define_base
define_log_steps = register_key_flags_in_core(_benchmark.define_log_steps)
define_benchmark = register_key_flags_in_core(_benchmark.define_benchmark)
define_device = register_key_flags_in_core(_device.define_device)
define_image = register_key_flags_in_core(_misc.define_image)
define_performance = register_key_flags_in_core(_performance.define_performance)
define_distribution = register_key_flags_in_core(
    _distribution.define_distribution)

help_wrap = _conventions.help_wrap

get_num_gpus = _base.get_num_gpus
get_tf_dtype = _performance.get_tf_dtype
get_loss_scale = _performance.get_loss_scale

DTYPE_MAP = _performance.DTYPE_MAP
require_cloud_storage = _device.require_cloud_storage


def _get_nondefault_flags_as_dict():
  """Returns the nondefault flags as a dict from flag name to value."""
  nondefault_flags = {}
  for flag_name in flags.FLAGS:
    flag_value = getattr(flags.FLAGS, flag_name)
    if (flag_name != flags.FLAGS[flag_name].short_name and
        flag_value != flags.FLAGS[flag_name].default):
      nondefault_flags[flag_name] = flag_value
  return nondefault_flags


def get_nondefault_flags_as_str():
  """Returns flags as a string that can be passed as command line arguments.

  E.g., returns: "--batch_size=256 --use_synthetic_data" for the following code
  block:

  ```
  flags.FLAGS.batch_size = 256
  flags.FLAGS.use_synthetic_data = True
  print(get_nondefault_flags_as_str())
  ```

  Only flags with nondefault values are returned, as passing default flags as
  command line arguments has no effect.

  Returns:
    A string with the flags, that can be passed as command line arguments to a
    program to use the flags.
  """
  nondefault_flags = _get_nondefault_flags_as_dict()
  flag_strings = []
  for name, value in sorted(nondefault_flags.items()):
    if isinstance(value, bool):
      flag_str = '--{}'.format(name) if value else '--no{}'.format(name)
    elif isinstance(value, list):
      flag_str = '--{}={}'.format(name, ','.join(value))
    else:
      flag_str = '--{}={}'.format(name, value)
    flag_strings.append(flag_str)
  return ' '.join(shlex_quote(flag_str) for flag_str in flag_strings)
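As a usage note, this public interface is normally imported as `flags_core` (as flags_test.py below does). A minimal, assumed example of defining flags, reparsing, and serializing the nondefault values; it presumes the companion `_base`/`_benchmark` modules referenced above are also present in the repository:

```python
from absl import flags

from official.utils.flags import core as flags_core

# Register flags and mark them as key flags in core (see
# register_key_flags_in_core above); the chosen arguments are illustrative.
flags_core.define_performance(dtype=True, synthetic_data=True)
flags_core.parse_flags(["prog"])

# Mirrors the docstring example of get_nondefault_flags_as_str():
flags.FLAGS.dtype = "fp16"
flags.FLAGS.use_synthetic_data = True
print(flags_core.get_nondefault_flags_as_str())
# -> "--dtype=fp16 --use_synthetic_data"
```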
official/utils/flags/flags_test.py
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

from absl import flags
import tensorflow as tf

from official.utils.flags import core as flags_core  # pylint: disable=g-bad-import-order


def define_flags():
  flags_core.define_base(clean=True, num_gpu=False, stop_threshold=True,
                         hooks=True, train_epochs=True,
                         epochs_between_evals=True)
  flags_core.define_performance(
      num_parallel_calls=True,
      inter_op=True,
      intra_op=True,
      loss_scale=True,
      synthetic_data=True,
      dtype=True)
  flags_core.define_image()
  flags_core.define_benchmark()


class BaseTester(unittest.TestCase):

  @classmethod
  def setUpClass(cls):
    super(BaseTester, cls).setUpClass()
    define_flags()

  def test_default_setting(self):
    """Test to ensure fields exist and defaults can be set."""

    defaults = dict(
        data_dir="dfgasf",
        model_dir="dfsdkjgbs",
        train_epochs=534,
        epochs_between_evals=15,
        batch_size=256,
        hooks=["LoggingTensorHook"],
        num_parallel_calls=18,
        inter_op_parallelism_threads=5,
        intra_op_parallelism_threads=10,
        data_format="channels_first")

    flags_core.set_defaults(**defaults)
    flags_core.parse_flags()

    for key, value in defaults.items():
      assert flags.FLAGS.get_flag_value(name=key, default=None) == value

  def test_benchmark_setting(self):
    defaults = dict(
        hooks=["LoggingMetricHook"],
        benchmark_log_dir="/tmp/12345",
        gcp_project="project_abc",
    )

    flags_core.set_defaults(**defaults)
    flags_core.parse_flags()

    for key, value in defaults.items():
      assert flags.FLAGS.get_flag_value(name=key, default=None) == value

  def test_booleans(self):
    """Test to ensure boolean flags trigger as expected."""
    flags_core.parse_flags([__file__, "--use_synthetic_data"])

    assert flags.FLAGS.use_synthetic_data

  def test_parse_dtype_info(self):
    flags_core.parse_flags([__file__, "--dtype", "fp16"])
    self.assertEqual(flags_core.get_tf_dtype(flags.FLAGS), tf.float16)
    self.assertEqual(flags_core.get_loss_scale(flags.FLAGS,
                                               default_for_fp16=2), 2)

    flags_core.parse_flags([__file__, "--dtype", "fp16", "--loss_scale", "5"])
    self.assertEqual(flags_core.get_loss_scale(flags.FLAGS,
                                               default_for_fp16=2), 5)

    flags_core.parse_flags(
        [__file__, "--dtype", "fp16", "--loss_scale", "dynamic"])
    self.assertEqual(flags_core.get_loss_scale(flags.FLAGS,
                                               default_for_fp16=2), "dynamic")

    flags_core.parse_flags([__file__, "--dtype", "fp32"])
    self.assertEqual(flags_core.get_tf_dtype(flags.FLAGS), tf.float32)
    self.assertEqual(flags_core.get_loss_scale(flags.FLAGS,
                                               default_for_fp16=2), 1)

    flags_core.parse_flags([__file__, "--dtype", "fp32", "--loss_scale", "5"])
    self.assertEqual(flags_core.get_loss_scale(flags.FLAGS,
                                               default_for_fp16=2), 5)

    with self.assertRaises(SystemExit):
      flags_core.parse_flags([__file__, "--dtype", "int8"])

    with self.assertRaises(SystemExit):
      flags_core.parse_flags(
          [__file__, "--dtype", "fp16", "--loss_scale", "abc"])

  def test_get_nondefault_flags_as_str(self):
    defaults = dict(
        clean=True,
        data_dir="abc",
        hooks=["LoggingTensorHook"],
        stop_threshold=1.5,
        use_synthetic_data=False)
    flags_core.set_defaults(**defaults)
    flags_core.parse_flags()

    expected_flags = ""
    self.assertEqual(flags_core.get_nondefault_flags_as_str(), expected_flags)

    flags.FLAGS.clean = False
    expected_flags += "--noclean"
    self.assertEqual(flags_core.get_nondefault_flags_as_str(), expected_flags)

    flags.FLAGS.data_dir = "xyz"
    expected_flags += " --data_dir=xyz"
    self.assertEqual(flags_core.get_nondefault_flags_as_str(), expected_flags)

    flags.FLAGS.hooks = ["aaa", "bbb", "ccc"]
    expected_flags += " --hooks=aaa,bbb,ccc"
    self.assertEqual(flags_core.get_nondefault_flags_as_str(), expected_flags)

    flags.FLAGS.stop_threshold = 3.
    expected_flags += " --stop_threshold=3.0"
    self.assertEqual(flags_core.get_nondefault_flags_as_str(), expected_flags)

    flags.FLAGS.use_synthetic_data = True
    expected_flags += " --use_synthetic_data"
    self.assertEqual(flags_core.get_nondefault_flags_as_str(), expected_flags)

    # Assert that explicitly setting a flag to its default value does not
    # cause it to appear in the string
    flags.FLAGS.use_synthetic_data = False
    expected_flags = expected_flags[:-len(" --use_synthetic_data")]
    self.assertEqual(flags_core.get_nondefault_flags_as_str(), expected_flags)


if __name__ == "__main__":
  unittest.main()
official/utils/flags/guidelines.md
# Using flags in official models
1. **All common flags must be incorporated in the models.**

   Common flags (i.e. batch_size, model_dir, etc.) are provided by various flag definition
   functions, and channeled through `official.utils.flags.core`. For instance, to define common
   supervised learning parameters one could use the following code:

   ```python
   from absl import app as absl_app
   from absl import flags

   from official.utils.flags import core as flags_core


   def define_flags():
     flags_core.define_base()
     flags.adopt_module_key_flags(flags_core)


   def main(_):
     flags_obj = flags.FLAGS
     print(flags_obj)


   if __name__ == "__main__":
     define_flags()
     absl_app.run(main)
   ```

2. **Validate flag values.**

   See the [Validators](#validators) section for implementation details.

   Validators in the official model repo should not access the file system, such as verifying
   that files exist, due to the strict ordering requirements.

3. **Flag values should not be mutated.**

   Instead of mutating flag values, use getter functions to return the desired values. An example
   getter function is the `get_tf_dtype` function below:

   ```python
   # Map string to TensorFlow dtype
   DTYPE_MAP = {
       "fp16": tf.float16,
       "fp32": tf.float32,
   }

   def get_tf_dtype(flags_obj):
     if getattr(flags_obj, "fp16_implementation", None) == "graph_rewrite":
       # If the graph_rewrite is used, we build the graph with fp32, and let the
       # graph rewrite change ops to fp16.
       return tf.float32
     return DTYPE_MAP[flags_obj.dtype]

   def main(_):
     flags_obj = flags.FLAGS
     # Do not mutate flags_obj
     # if flags_obj.fp16_implementation == "graph_rewrite":
     #   flags_obj.dtype = "float32"  # Don't do this
     print(get_tf_dtype(flags_obj))
     ...
   ```
official/utils/hyperparams_flags.py
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common flags for importing hyperparameters."""
from absl import flags

from official.utils.flags import core as flags_core

FLAGS = flags.FLAGS


def define_gin_flags():
  """Define common gin configurable flags."""
  flags.DEFINE_multi_string('gin_file', None,
                            'List of paths to the config files.')
  flags.DEFINE_multi_string(
      'gin_param', None, 'Newline separated list of Gin parameter bindings.')


def define_common_hparams_flags():
  """Define the common flags across models."""

  flags.DEFINE_string(
      'model_dir',
      default=None,
      help=('The directory where the model and training/evaluation summaries '
            'are stored.'))

  flags.DEFINE_integer(
      'train_batch_size', default=None, help='Batch size for training.')

  flags.DEFINE_integer(
      'eval_batch_size', default=None, help='Batch size for evaluation.')

  flags.DEFINE_string(
      'precision',
      default=None,
      help=('Precision to use; one of: {bfloat16, float32}'))

  flags.DEFINE_string(
      'config_file',
      default=None,
      help=('A YAML file which specifies overrides. Note that this file can be '
            'used as an override template to override the default parameters '
            'specified in Python. If the same parameter is specified in both '
            '`--config_file` and `--params_override`, the one in '
            '`--params_override` will be used finally.'))

  flags.DEFINE_string(
      'params_override',
      default=None,
      help=('a YAML/JSON string or a YAML file which specifies additional '
            'overrides over the default parameters and those specified in '
            '`--config_file`. Note that this is supposed to be used only to '
            'override the model parameters, but not the parameters like TPU '
            'specific flags. One canonical use case of `--config_file` and '
            '`--params_override` is users first define a template config file '
            'using `--config_file`, then use `--params_override` to adjust the '
            'minimal set of tuning parameters, for example setting up different '
            '`train_batch_size`. '
            'The final override order of parameters: default_model_params --> '
            'params from config_file --> params in params_override. '
            'See also the help message of `--config_file`.'))

  flags.DEFINE_integer('save_checkpoint_freq', None,
                       'Number of steps to save checkpoint.')


def initialize_common_flags():
  """Define the common flags across models."""
  define_common_hparams_flags()

  flags_core.define_device(tpu=True)
  flags_core.define_base(
      num_gpu=True, model_dir=False, data_dir=False, batch_size=False)
  flags_core.define_distribution(worker_hosts=True, task_index=True)
  flags_core.define_performance(all_reduce_alg=True, num_packs=True)

  # Reset the default value of num_gpus to zero.
  FLAGS.num_gpus = 0

  flags.DEFINE_string(
      'strategy_type', 'mirrored', 'Type of distribute strategy. '
      'One of mirrored, tpu and multiworker.')


def strategy_flags_dict():
  """Returns TPU and/or GPU related flags in a dictionary."""
  return {
      'distribution_strategy': FLAGS.strategy_type,
      # TPUStrategy related flags.
      'tpu': FLAGS.tpu,
      # MultiWorkerMirroredStrategy related flags.
      'all_reduce_alg': FLAGS.all_reduce_alg,
      'worker_hosts': FLAGS.worker_hosts,
      'task_index': FLAGS.task_index,
      # MirroredStrategy and OneDeviceStrategy
      'num_gpus': FLAGS.num_gpus,
      'num_packs': FLAGS.num_packs,
  }


def hparam_flags_dict():
  """Returns model params related flags in a dictionary."""
  return {
      'data_dir': FLAGS.data_dir,
      'model_dir': FLAGS.model_dir,
      'train_batch_size': FLAGS.train_batch_size,
      'eval_batch_size': FLAGS.eval_batch_size,
      'precision': FLAGS.precision,
      'config_file': FLAGS.config_file,
      'params_override': FLAGS.params_override,
  }
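A hypothetical caller, sketched here for orientation only; it assumes the `_base` module used by `flags_core.define_base` is available and that flag values come from the (illustrative) command line in the comment:

```python
from absl import app as absl_app

from official.utils import hyperparams_flags

hyperparams_flags.initialize_common_flags()


def main(_):
  # Collect strategy- and hparam-related flag values as plain dicts, e.g. to
  # forward them into a params object or a distribution-strategy helper.
  print(hyperparams_flags.strategy_flags_dict())
  print(hyperparams_flags.hparam_flags_dict())


if __name__ == "__main__":
  # e.g. python train.py --strategy_type=mirrored --train_batch_size=256
  absl_app.run(main)
```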
official/utils/misc/__init__.py
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
official/utils/misc/__pycache__/__init__.cpython-37.pyc
File added
official/utils/misc/__pycache__/keras_utils.cpython-37.pyc
File added
official/utils/misc/__pycache__/model_helpers.cpython-37.pyc
File added
official/utils/misc/keras_utils.py
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions for the Keras implementations of models."""
import multiprocessing
import os
import time

from absl import logging
import tensorflow as tf
from tensorflow.python.eager import monitoring

global_batch_size_gauge = monitoring.IntGauge(
    '/tensorflow/training/global_batch_size', 'TF training global batch size')

first_batch_time_gauge = monitoring.IntGauge(
    '/tensorflow/training/first_batch',
    'TF training start/end time for first batch (unix epoch time in us).',
    'type')

first_batch_start_time = first_batch_time_gauge.get_cell('start')
first_batch_end_time = first_batch_time_gauge.get_cell('end')


class BatchTimestamp(object):
  """A structure to store batch time stamp."""

  def __init__(self, batch_index, timestamp):
    self.batch_index = batch_index
    self.timestamp = timestamp

  def __repr__(self):
    return "'BatchTimestamp<batch_index: {}, timestamp: {}>'".format(
        self.batch_index, self.timestamp)


class TimeHistory(tf.keras.callbacks.Callback):
  """Callback for Keras models."""

  def __init__(self, batch_size, log_steps, initial_step=0, logdir=None):
    """Callback for logging performance.

    Args:
      batch_size: Total batch size.
      log_steps: Interval of steps between logging of batch level stats.
      initial_step: Optional, initial step.
      logdir: Optional directory to write TensorBoard summaries.
    """
    # TODO(wcromar): remove this parameter and rely on `logs` parameter of
    # on_train_batch_end()
    self.batch_size = batch_size
    super(TimeHistory, self).__init__()
    self.log_steps = log_steps
    self.last_log_step = initial_step
    self.steps_before_epoch = initial_step
    self.steps_in_epoch = 0
    self.start_time = None

    global_batch_size_gauge.get_cell().set(batch_size)

    if logdir:
      self.summary_writer = tf.summary.create_file_writer(logdir)
    else:
      self.summary_writer = None

    # Logs start of step 1 then end of each step based on log_steps interval.
    self.timestamp_log = []

    # Records the time each epoch takes to run from start to finish of epoch.
    self.epoch_runtime_log = []

  @property
  def global_steps(self):
    """The current 1-indexed global step."""
    return self.steps_before_epoch + self.steps_in_epoch

  @property
  def average_steps_per_second(self):
    """The average training steps per second across all epochs."""
    return self.global_steps / sum(self.epoch_runtime_log)

  @property
  def average_examples_per_second(self):
    """The average number of training examples per second across all epochs."""
    return self.average_steps_per_second * self.batch_size

  def get_examples_per_sec(self, warmup=1):
    """Calculates examples/sec through timestamp_log and skip warmup period."""
    # First entry in timestamp_log is the start of the step 1. The rest of the
    # entries are the end of each step recorded.
    time_log = self.timestamp_log
    seconds = time_log[-1].timestamp - time_log[warmup].timestamp
    steps = time_log[-1].batch_index - time_log[warmup].batch_index
    return self.batch_size * steps / seconds

  def get_startup_time(self, start_time_sec):
    return self.timestamp_log[0].timestamp - start_time_sec

  def on_train_end(self, logs=None):
    self.train_finish_time = time.time()

    if self.summary_writer:
      self.summary_writer.flush()

  def on_epoch_begin(self, epoch, logs=None):
    self.epoch_start = time.time()

  def on_batch_begin(self, batch, logs=None):
    if not self.start_time:
      self.start_time = time.time()
      if not first_batch_start_time.value():
        first_batch_start_time.set(int(self.start_time * 1000000))

    # Record the timestamp of the first global step
    if not self.timestamp_log:
      self.timestamp_log.append(
          BatchTimestamp(self.global_steps, self.start_time))

  def on_batch_end(self, batch, logs=None):
    """Records elapsed time of the batch and calculates examples per second."""
    if not first_batch_end_time.value():
      first_batch_end_time.set(int(time.time() * 1000000))
    self.steps_in_epoch = batch + 1
    steps_since_last_log = self.global_steps - self.last_log_step
    if steps_since_last_log >= self.log_steps:
      now = time.time()
      elapsed_time = now - self.start_time
      steps_per_second = steps_since_last_log / elapsed_time
      examples_per_second = steps_per_second * self.batch_size

      self.timestamp_log.append(BatchTimestamp(self.global_steps, now))
      logging.info(
          'TimeHistory: %.2f seconds, %.2f examples/second between steps %d '
          'and %d', elapsed_time, examples_per_second, self.last_log_step,
          self.global_steps)

      if self.summary_writer:
        with self.summary_writer.as_default():
          tf.summary.scalar('steps_per_second', steps_per_second,
                            self.global_steps)
          tf.summary.scalar('examples_per_second', examples_per_second,
                            self.global_steps)

      self.last_log_step = self.global_steps
      self.start_time = None

  def on_epoch_end(self, epoch, logs=None):
    epoch_run_time = time.time() - self.epoch_start
    self.epoch_runtime_log.append(epoch_run_time)

    self.steps_before_epoch += self.steps_in_epoch
    self.steps_in_epoch = 0


class SimpleCheckpoint(tf.keras.callbacks.Callback):
  """Keras callback to save tf.train.Checkpoints."""

  def __init__(self, checkpoint_manager):
    super(SimpleCheckpoint, self).__init__()
    self.checkpoint_manager = checkpoint_manager

  def on_epoch_end(self, epoch, logs=None):
    step_counter = self.checkpoint_manager._step_counter.numpy()  # pylint: disable=protected-access
    self.checkpoint_manager.save(checkpoint_number=step_counter)


def set_session_config(enable_xla=False):
  """Sets the session config."""
  if enable_xla:
    tf.config.optimizer.set_jit(True)


# TODO(hongkuny): remove set_config_v2 globally.
set_config_v2 = set_session_config


def set_gpu_thread_mode_and_count(gpu_thread_mode,
                                  datasets_num_private_threads, num_gpus,
                                  per_gpu_thread_count):
  """Set GPU thread mode and count, and adjust dataset threads count."""
  cpu_count = multiprocessing.cpu_count()
  logging.info('Logical CPU cores: %s', cpu_count)

  # Allocate private thread pool for each GPU to schedule and launch kernels
  per_gpu_thread_count = per_gpu_thread_count or 2
  os.environ['TF_GPU_THREAD_MODE'] = gpu_thread_mode
  os.environ['TF_GPU_THREAD_COUNT'] = str(per_gpu_thread_count)
  logging.info('TF_GPU_THREAD_COUNT: %s', os.environ['TF_GPU_THREAD_COUNT'])
  logging.info('TF_GPU_THREAD_MODE: %s', os.environ['TF_GPU_THREAD_MODE'])

  # Limit data preprocessing threadpool to CPU cores minus number of total GPU
  # private threads and memory copy threads.
  total_gpu_thread_count = per_gpu_thread_count * num_gpus
  num_runtime_threads = num_gpus
  if not datasets_num_private_threads:
    datasets_num_private_threads = min(
        cpu_count - total_gpu_thread_count - num_runtime_threads,
        num_gpus * 8)
    logging.info('Set datasets_num_private_threads to %s',
                 datasets_num_private_threads)
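A small sketch, not part of the commit, of wiring the `TimeHistory` callback into a Keras training loop; the toy model, dataset, and numbers below are assumptions chosen only to make the example self-contained:

```python
import tensorflow as tf

from official.utils.misc import keras_utils

# Assumed toy model and dataset; the callback wiring is the point here.
model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
model.compile(optimizer="sgd", loss="mse")

time_callback = keras_utils.TimeHistory(
    batch_size=64,        # global batch size used for examples/sec math
    log_steps=100,        # log throughput every 100 steps
    logdir="/tmp/logs")   # optional TensorBoard summaries

dataset = tf.data.Dataset.from_tensors(
    (tf.zeros([64, 5]), tf.zeros([64, 10]))).repeat(500)
model.fit(dataset, epochs=2, callbacks=[time_callback])
print("avg examples/sec:", time_callback.average_examples_per_second)
```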
official/utils/misc/model_helpers.py
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Miscellaneous functions that can be called by models."""
import numbers

from absl import logging
import tensorflow as tf
from tensorflow.python.util import nest

# pylint:disable=logging-format-interpolation


def past_stop_threshold(stop_threshold, eval_metric):
  """Return a boolean representing whether a model should be stopped.

  Args:
    stop_threshold: float, the threshold above which a model should stop
      training.
    eval_metric: float, the current value of the relevant metric to check.

  Returns:
    True if training should stop, False otherwise.

  Raises:
    ValueError: if either stop_threshold or eval_metric is not a number
  """
  if stop_threshold is None:
    return False

  if not isinstance(stop_threshold, numbers.Number):
    raise ValueError("Threshold for checking stop conditions must be a number.")
  if not isinstance(eval_metric, numbers.Number):
    raise ValueError("Eval metric being checked against stop conditions "
                     "must be a number.")

  if eval_metric >= stop_threshold:
    logging.info("Stop threshold of {} was passed with metric value {}.".format(
        stop_threshold, eval_metric))
    return True

  return False


def generate_synthetic_data(input_shape,
                            input_value=0,
                            input_dtype=None,
                            label_shape=None,
                            label_value=0,
                            label_dtype=None):
  """Create a repeating dataset with constant values.

  Args:
    input_shape: a tf.TensorShape object or nested tf.TensorShapes. The shape
      of the input data.
    input_value: Value of each input element.
    input_dtype: Input dtype. If None, will be inferred by the input value.
    label_shape: a tf.TensorShape object or nested tf.TensorShapes. The shape
      of the label data.
    label_value: Value of each label element.
    label_dtype: Label dtype. If None, will be inferred by the label value.

  Returns:
    Dataset of tensors or tuples of tensors (if label_shape is set).
  """
  # TODO(kathywu): Replace with SyntheticDataset once it is in contrib.
  element = input_element = nest.map_structure(
      lambda s: tf.constant(input_value, input_dtype, s), input_shape)

  if label_shape:
    label_element = nest.map_structure(
        lambda s: tf.constant(label_value, label_dtype, s), label_shape)
    element = (input_element, label_element)

  return tf.data.Dataset.from_tensors(element).repeat()


def apply_clean(flags_obj):
  if flags_obj.clean and tf.io.gfile.exists(flags_obj.model_dir):
    logging.info("--clean flag set. Removing existing model dir:"
                 " {}".format(flags_obj.model_dir))
    tf.io.gfile.rmtree(flags_obj.model_dir)
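For illustration, a short example of using the two helpers above; the ImageNet-like shapes and threshold values are assumptions, not something this commit fixes:

```python
import tensorflow as tf

from official.utils.misc import model_helpers

# Synthetic ImageNet-shaped batches: constant images and labels, repeated
# forever, then batched for a performance-debugging run.
synth = model_helpers.generate_synthetic_data(
    input_shape=tf.TensorShape([224, 224, 3]),
    input_value=127,
    input_dtype=tf.float32,
    label_shape=tf.TensorShape([]),
    label_value=1,
    label_dtype=tf.int32).batch(32)

# Early-stopping check against an accuracy threshold.
if model_helpers.past_stop_threshold(stop_threshold=0.76, eval_metric=0.761):
  print("Target accuracy reached; stopping training.")
```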
official/utils/misc/model_helpers_test.py
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Model Helper functions."""
import tensorflow as tf  # pylint: disable=g-bad-import-order

from official.utils.misc import model_helpers


class PastStopThresholdTest(tf.test.TestCase):
  """Tests for past_stop_threshold."""

  def setUp(self):
    super(PastStopThresholdTest, self).setUp()
    tf.compat.v1.disable_eager_execution()

  def test_past_stop_threshold(self):
    """Tests for normal operating conditions."""
    self.assertTrue(model_helpers.past_stop_threshold(0.54, 1))
    self.assertTrue(model_helpers.past_stop_threshold(54, 100))
    self.assertFalse(model_helpers.past_stop_threshold(0.54, 0.1))
    self.assertFalse(model_helpers.past_stop_threshold(-0.54, -1.5))
    self.assertTrue(model_helpers.past_stop_threshold(-0.54, 0))
    self.assertTrue(model_helpers.past_stop_threshold(0, 0))
    self.assertTrue(model_helpers.past_stop_threshold(0.54, 0.54))

  def test_past_stop_threshold_none_false(self):
    """Tests that check None returns false."""
    self.assertFalse(model_helpers.past_stop_threshold(None, -1.5))
    self.assertFalse(model_helpers.past_stop_threshold(None, None))
    self.assertFalse(model_helpers.past_stop_threshold(None, 1.5))
    # Zero should be okay, though.
    self.assertTrue(model_helpers.past_stop_threshold(0, 1.5))

  def test_past_stop_threshold_not_number(self):
    """Tests for error conditions."""
    with self.assertRaises(ValueError):
      model_helpers.past_stop_threshold('str', 1)

    with self.assertRaises(ValueError):
      model_helpers.past_stop_threshold('str', tf.constant(5))

    with self.assertRaises(ValueError):
      model_helpers.past_stop_threshold('str', 'another')

    with self.assertRaises(ValueError):
      model_helpers.past_stop_threshold(0, None)

    with self.assertRaises(ValueError):
      model_helpers.past_stop_threshold(0.7, 'str')

    with self.assertRaises(ValueError):
      model_helpers.past_stop_threshold(tf.constant(4), None)


class SyntheticDataTest(tf.test.TestCase):
  """Tests for generate_synthetic_data."""

  def test_generate_synethetic_data(self):
    input_element, label_element = tf.compat.v1.data.make_one_shot_iterator(
        model_helpers.generate_synthetic_data(
            input_shape=tf.TensorShape([5]),
            input_value=123,
            input_dtype=tf.float32,
            label_shape=tf.TensorShape([]),
            label_value=456,
            label_dtype=tf.int32)).get_next()

    with self.session() as sess:
      for n in range(5):
        inp, lab = sess.run((input_element, label_element))
        self.assertAllClose(inp, [123., 123., 123., 123., 123.])
        self.assertEquals(lab, 456)

  def test_generate_only_input_data(self):
    d = model_helpers.generate_synthetic_data(
        input_shape=tf.TensorShape([4]),
        input_value=43.5,
        input_dtype=tf.float32)

    element = tf.compat.v1.data.make_one_shot_iterator(d).get_next()
    self.assertFalse(isinstance(element, tuple))

    with self.session() as sess:
      inp = sess.run(element)
      self.assertAllClose(inp, [43.5, 43.5, 43.5, 43.5])

  def test_generate_nested_data(self):
    d = model_helpers.generate_synthetic_data(
        input_shape={
            'a': tf.TensorShape([2]),
            'b': {
                'c': tf.TensorShape([3]),
                'd': tf.TensorShape([])
            }
        },
        input_value=1.1)

    element = tf.compat.v1.data.make_one_shot_iterator(d).get_next()
    self.assertIn('a', element)
    self.assertIn('b', element)
    self.assertEquals(len(element['b']), 2)
    self.assertIn('c', element['b'])
    self.assertIn('d', element['b'])
    self.assertNotIn('c', element)

    with self.session() as sess:
      inp = sess.run(element)
      self.assertAllClose(inp['a'], [1.1, 1.1])
      self.assertAllClose(inp['b']['c'], [1.1, 1.1, 1.1])
      self.assertAllClose(inp['b']['d'], 1.1)


if __name__ == '__main__':
  tf.test.main()
official/vision/image_classification/resnet/__init__.py
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
official/vision/image_classification/resnet/__pycache__/__init__.cpython-37.pyc
File added
official/vision/image_classification/resnet/__pycache__/common.cpython-37.pyc
File added
official/vision/image_classification/resnet/__pycache__/imagenet_preprocessing.cpython-37.pyc
File added