Unverified Commit 086d9148 authored by Taylor Robie's avatar Taylor Robie Committed by GitHub
Browse files

Unified arg parser (#3574)

Create groups of arg parsers and convert the official resnet model to
the new arg parsers.
parent 86b1f07b
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
\ No newline at end of file
......@@ -36,6 +36,8 @@ import os
import tensorflow as tf
from official.utils.arg_parsers import parsers # pylint: disable=g-bad-import-order
_BATCH_NORM_DECAY = 0.997
_BATCH_NORM_EPSILON = 1e-5
......@@ -779,71 +781,15 @@ class ResnetArgParser(argparse.ArgumentParser):
"""
def __init__(self, resnet_size_choices=None):
super(ResnetArgParser, self).__init__()
self.add_argument(
'--data_dir', type=str, default='/tmp/resnet_data',
help='The directory where the input data is stored.')
self.add_argument(
'--num_parallel_calls', type=int, default=5,
help='The number of records that are processed in parallel '
'during input processing. This can be optimized per data set but '
'for generally homogeneous data sets, should be approximately the '
'number of available CPU cores.')
self.add_argument(
'--model_dir', type=str, default='/tmp/resnet_model',
help='The directory where the model will be stored.')
super(ResnetArgParser, self).__init__(parents=[
parsers.BaseParser(),
parsers.PerformanceParser(),
parsers.ImageModelParser(),
])
self.add_argument(
'--resnet_size', type=int, default=50,
'--resnet_size', '-rs', type=int, default=50,
choices=resnet_size_choices,
help='The size of the ResNet model to use.')
self.add_argument(
'--train_epochs', type=int, default=100,
help='The number of epochs to use for training.')
self.add_argument(
'--epochs_per_eval', type=int, default=1,
help='The number of training epochs to run between evaluations.')
self.add_argument(
'--batch_size', type=int, default=32,
help='Batch size for training and evaluation.')
self.add_argument(
'--data_format', type=str, default=None,
choices=['channels_first', 'channels_last'],
help='A flag to override the data format used in the model. '
'channels_first provides a performance boost on GPU but '
'is not always compatible with CPU. If left unspecified, '
'the data format will be chosen automatically based on '
'whether TensorFlow was built for CPU or GPU.')
self.add_argument(
'--multi_gpu', action='store_true',
help='If set, run across all available GPUs.')
self.add_argument(
'-v', '--version', type=int, choices=[1, 2], dest="version",
default=DEFAULT_VERSION,
help="Version of ResNet. (1 or 2) See README.md for details."
help='[default: %(default)s]The size of the ResNet model to use.',
metavar='<RS>'
)
# Advanced args
self.add_argument(
'--use_synthetic_data', action='store_true',
help='If set, use fake data (zeroes) instead of a real dataset. '
'This mode is useful for performance debugging, as it removes '
'input processing steps, but will not learn anything.')
self.add_argument(
'--inter_op_parallelism_threads', type=int, default=0,
help='Number of inter_op_parallelism_threads to use for CPU. '
'See TensorFlow config.proto for details.')
self.add_argument(
'--intra_op_parallelism_threads', type=int, default=0,
help='Number of intra_op_parallelism_threads to use for CPU. '
'See TensorFlow config.proto for details.')
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
\ No newline at end of file
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Collection of parsers which are shared among the official models.
The parsers in this module are intended to be used as parents to all arg
parsers in official models. For instance, one might define a new class:
class ExampleParser(argparse.ArgumentParser):
def __init__(self):
super(ExampleParser, self).__init__(parents=[
official.utils.arg_parsers.LocationParser(data_dir=True, model_dir=True),
official.utils.arg_parsers.DummyParser(use_synthetic_data=True),
])
self.add_argument(
"--application_specific_arg", "-asa", type=int, default=123,
help="[default: %(default)s] This arg is application specific.",
metavar="<ASA>"
)
Notes about add_argument():
Argparse will automatically template in default values in help messages if
the "%(default)s" string appears in the message. Using the example above:
parser = ExampleParser()
parser.set_defaults(application_specific_arg=3141592)
parser.parse_args(["-h"])
When the help text is generated, it will display 3141592 to the user. (Even
though the default was 123 when the flag was created.)
The metavar variable determines how the flag will appear in help text. If
not specified, the convention is to use name.upper(). Thus rather than:
--application_specific_arg APPLICATION_SPECIFIC_ARG, -asa APPLICATION_SPECIFIC_ARG
if metavar="<ASA>" is set, the user sees:
--application_specific_arg <ASA>, -asa <ASA>
"""
import argparse
class BaseParser(argparse.ArgumentParser):
"""Parser to contain flags which will be nearly universal across models.
Args:
add_help: Create the "--help" flag. False if class instance is a parent.
data_dir: Create a flag for specifying the input data directory.
model_dir: Create a flag for specifying the model file directory.
train_epochs: Create a flag to specify the number of training epochs.
epochs_per_eval: Create a flag to specify the frequency of testing.
batch_size: Create a flag to specify the batch size.
multi_gpu: Create a flag to allow the use of all available GPUs.
"""
def __init__(self, add_help=False, data_dir=True, model_dir=True,
train_epochs=True, epochs_per_eval=True, batch_size=True,
multi_gpu=True):
super(BaseParser, self).__init__(add_help=add_help)
if data_dir:
self.add_argument(
"--data_dir", "-dd", default="/tmp",
help="[default: %(default)s] The location of the input data.",
metavar="<DD>",
)
if model_dir:
self.add_argument(
"--model_dir", "-md", default="/tmp",
help="[default: %(default)s] The location of the model files.",
metavar="<MD>",
)
if train_epochs:
self.add_argument(
"--train_epochs", "-te", type=int, default=1,
help="[default: %(default)s] The number of epochs used to train.",
metavar="<TE>"
)
if epochs_per_eval:
self.add_argument(
"--epochs_per_eval", "-epe", type=int, default=1,
help="[default: %(default)s] The number of training epochs to run "
"between evaluations.",
metavar="<EPE>"
)
if batch_size:
self.add_argument(
"--batch_size", "-bs", type=int, default=32,
help="[default: %(default)s] Batch size for training and evaluation.",
metavar="<BS>"
)
if multi_gpu:
self.add_argument(
"--multi_gpu", action="store_true",
help="If set, run across all available GPUs."
)
class PerformanceParser(argparse.ArgumentParser):
"""Default parser for specifying performance tuning arguments.
Args:
add_help: Create the "--help" flag. False if class instance is a parent.
num_parallel_calls: Create a flag to specify parallelism of data loading.
inter_op: Create a flag to allow specification of inter op threads.
intra_op: Create a flag to allow specification of intra op threads.
"""
def __init__(self, add_help=False, num_parallel_calls=True, inter_op=True,
intra_op=True, use_synthetic_data=True):
super(PerformanceParser, self).__init__(add_help=add_help)
if num_parallel_calls:
self.add_argument(
"--num_parallel_calls", "-npc",
type=int, default=5,
help="[default: %(default)s] The number of records that are "
"processed in parallel during input processing. This can be "
"optimized per data set but for generally homogeneous data "
"sets, should be approximately the number of available CPU "
"cores.",
metavar="<NPC>"
)
if inter_op:
self.add_argument(
"--inter_op_parallelism_threads", "-inter",
type=int, default=0,
help="[default: %(default)s Number of inter_op_parallelism_threads "
"to use for CPU. See TensorFlow config.proto for details.",
metavar="<INTER>"
)
if intra_op:
self.add_argument(
"--intra_op_parallelism_threads", "-intra",
type=int, default=0,
help="[default: %(default)s Number of intra_op_parallelism_threads "
"to use for CPU. See TensorFlow config.proto for details.",
metavar="<INTRA>"
)
if use_synthetic_data:
self.add_argument(
"--use_synthetic_data", "-synth",
action="store_true",
help="If set, use fake data (zeroes) instead of a real dataset. "
"This mode is useful for performance debugging, as it removes "
"input processing steps, but will not learn anything."
)
class ImageModelParser(argparse.ArgumentParser):
"""Default parser for specification image specific behavior.
Args:
add_help: Create the "--help" flag. False if class instance is a parent.
data_format: Create a flag to specify image axis convention.
"""
def __init__(self, add_help=False, data_format=True):
super(ImageModelParser, self).__init__(add_help=add_help)
if data_format:
self.add_argument(
"--data_format", "-df",
help="A flag to override the data format used in the model. "
"channels_first provides a performance boost on GPU but is not "
"always compatible with CPU. If left unspecified, the data "
"format will be chosen automatically based on whether TensorFlow"
"was built for CPU or GPU.",
metavar="<CF>",
)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import argparse
import unittest
from official.utils.arg_parsers import parsers
class TestParser(argparse.ArgumentParser):
"""Class to test canned parser functionality."""
def __init__(self):
super(TestParser, self).__init__(parents=[
parsers.BaseParser(),
parsers.PerformanceParser(num_parallel_calls=True, inter_op=True,
intra_op=True, use_synthetic_data=True),
parsers.ImageModelParser(data_format=True)
])
class BaseTester(unittest.TestCase):
def test_default_setting(self):
"""Test to ensure fields exist and defaults can be set.
"""
defaults = dict(
data_dir="dfgasf",
model_dir="dfsdkjgbs",
train_epochs=534,
epochs_per_eval=15,
batch_size=256,
num_parallel_calls=18,
inter_op_parallelism_threads=5,
intra_op_parallelism_thread=10,
data_format="channels_first"
)
parser = TestParser()
parser.set_defaults(**defaults)
namespace_vars = vars(parser.parse_args([]))
for key, value in defaults.items():
assert namespace_vars[key] == value
def test_booleans(self):
"""Test to ensure boolean flags trigger as expected.
"""
parser = TestParser()
namespace = parser.parse_args(["--multi_gpu", "--use_synthetic_data"])
assert namespace.multi_gpu
assert namespace.use_synthetic_data
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment