"integration-tests/vscode:/vscode.git/clone" did not exist on "142cdabed377772b763fc8d79a131b16ed991718"
Commit 4e0ca759 authored by Asim Shankar's avatar Asim Shankar
Browse files

Merge branch 'master' into mnist_sequential_and_estimator

parents e48a403e 310f70d5
...@@ -25,6 +25,7 @@ import tensorflow as tf # pylint: disable=g-bad-import-order ...@@ -25,6 +25,7 @@ import tensorflow as tf # pylint: disable=g-bad-import-order
from official.mnist import dataset from official.mnist import dataset
from official.utils.arg_parsers import parsers from official.utils.arg_parsers import parsers
from official.utils.logs import hooks_helper from official.utils.logs import hooks_helper
from official.utils.misc import model_helpers
LEARNING_RATE = 1e-4 LEARNING_RATE = 1e-4
...@@ -212,6 +213,10 @@ def main(argv): ...@@ -212,6 +213,10 @@ def main(argv):
eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn) eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
print('\nEvaluation results:\n\t%s\n' % eval_results) print('\nEvaluation results:\n\t%s\n' % eval_results)
if model_helpers.past_stop_threshold(
flags.stop_threshold, eval_results['accuracy']):
break
# Export the model # Export the model
if flags.export_dir is not None: if flags.export_dir is not None:
image = tf.placeholder(tf.float32, [None, 28, 28]) image = tf.placeholder(tf.float32, [None, 28, 28])
......
...@@ -164,8 +164,7 @@ class MNISTEagerArgParser(argparse.ArgumentParser): ...@@ -164,8 +164,7 @@ class MNISTEagerArgParser(argparse.ArgumentParser):
def __init__(self): def __init__(self):
super(MNISTEagerArgParser, self).__init__(parents=[ super(MNISTEagerArgParser, self).__init__(parents=[
parsers.BaseParser( parsers.EagerParser(),
epochs_between_evals=False, multi_gpu=False, hooks=False),
parsers.ImageModelParser()]) parsers.ImageModelParser()])
self.add_argument( self.add_argument(
......
...@@ -318,5 +318,6 @@ class BaseTest(tf.test.TestCase): ...@@ -318,5 +318,6 @@ class BaseTest(tf.test.TestCase):
extra_flags=['-v', '2', '-rs', '200'] extra_flags=['-v', '2', '-rs', '200']
) )
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -33,6 +33,7 @@ from official.utils.arg_parsers import parsers ...@@ -33,6 +33,7 @@ from official.utils.arg_parsers import parsers
from official.utils.export import export from official.utils.export import export
from official.utils.logs import hooks_helper from official.utils.logs import hooks_helper
from official.utils.logs import logger from official.utils.logs import logger
from official.utils.misc import model_helpers
################################################################################ ################################################################################
...@@ -438,6 +439,10 @@ def resnet_main(flags, model_function, input_function, shape=None): ...@@ -438,6 +439,10 @@ def resnet_main(flags, model_function, input_function, shape=None):
if benchmark_logger: if benchmark_logger:
benchmark_logger.log_estimator_evaluation_result(eval_results) benchmark_logger.log_estimator_evaluation_result(eval_results)
if model_helpers.past_stop_threshold(
flags.stop_threshold, eval_results['accuracy']):
break
if flags.export_dir is not None: if flags.export_dir is not None:
warn_on_multi_gpu_export(flags.multi_gpu) warn_on_multi_gpu_export(flags.multi_gpu)
......
...@@ -99,14 +99,17 @@ class BaseParser(argparse.ArgumentParser): ...@@ -99,14 +99,17 @@ class BaseParser(argparse.ArgumentParser):
model_dir: Create a flag for specifying the model file directory. model_dir: Create a flag for specifying the model file directory.
train_epochs: Create a flag to specify the number of training epochs. train_epochs: Create a flag to specify the number of training epochs.
epochs_between_evals: Create a flag to specify the frequency of testing. epochs_between_evals: Create a flag to specify the frequency of testing.
stop_threshold: Create a flag to specify a threshold accuracy or other
eval metric which should trigger the end of training.
batch_size: Create a flag to specify the batch size. batch_size: Create a flag to specify the batch size.
multi_gpu: Create a flag to allow the use of all available GPUs. multi_gpu: Create a flag to allow the use of all available GPUs.
hooks: Create a flag to specify hooks for logging. hooks: Create a flag to specify hooks for logging.
""" """
def __init__(self, add_help=False, data_dir=True, model_dir=True, def __init__(self, add_help=False, data_dir=True, model_dir=True,
train_epochs=True, epochs_between_evals=True, batch_size=True, train_epochs=True, epochs_between_evals=True,
multi_gpu=True, hooks=True): stop_threshold=True, batch_size=True, multi_gpu=True,
hooks=True):
super(BaseParser, self).__init__(add_help=add_help) super(BaseParser, self).__init__(add_help=add_help)
if data_dir: if data_dir:
...@@ -139,6 +142,15 @@ class BaseParser(argparse.ArgumentParser): ...@@ -139,6 +142,15 @@ class BaseParser(argparse.ArgumentParser):
metavar="<EBE>" metavar="<EBE>"
) )
if stop_threshold:
self.add_argument(
"--stop_threshold", "-st", type=float, default=None,
help="[default: %(default)s] If passed, training will stop at "
"the earlier of train_epochs and when the evaluation metric is "
"greater than or equal to stop_threshold.",
metavar="<ST>"
)
if batch_size: if batch_size:
self.add_argument( self.add_argument(
"--batch_size", "-bs", type=int, default=32, "--batch_size", "-bs", type=int, default=32,
...@@ -345,3 +357,15 @@ class BenchmarkParser(argparse.ArgumentParser): ...@@ -345,3 +357,15 @@ class BenchmarkParser(argparse.ArgumentParser):
" benchmark metric information will be uploaded.", " benchmark metric information will be uploaded.",
metavar="<BMT>" metavar="<BMT>"
) )
class EagerParser(BaseParser):
"""Remove options not relevant for Eager from the BaseParser."""
def __init__(self, add_help=False, data_dir=True, model_dir=True,
train_epochs=True, batch_size=True):
super(EagerParser, self).__init__(
add_help=add_help, data_dir=data_dir, model_dir=model_dir,
train_epochs=train_epochs, epochs_between_evals=False,
stop_threshold=False, batch_size=batch_size, multi_gpu=False,
hooks=False)
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Miscellaneous functions that can be called by models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numbers
import tensorflow as tf
def past_stop_threshold(stop_threshold, eval_metric):
"""Return a boolean representing whether a model should be stopped.
Args:
stop_threshold: float, the threshold above which a model should stop
training.
eval_metric: float, the current value of the relevant metric to check.
Returns:
True if training should stop, False otherwise.
Raises:
ValueError: if either stop_threshold or eval_metric is not a number
"""
if stop_threshold is None:
return False
if not isinstance(stop_threshold, numbers.Number):
raise ValueError("Threshold for checking stop conditions must be a number.")
if not isinstance(eval_metric, numbers.Number):
raise ValueError("Eval metric being checked against stop conditions "
"must be a number.")
if eval_metric >= stop_threshold:
tf.logging.info(
"Stop threshold of {} was passed with metric value {}.".format(
stop_threshold, eval_metric))
return True
return False
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" Tests for Model Helper functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.misc import model_helpers
class PastStopThresholdTest(tf.test.TestCase):
"""Tests for past_stop_threshold."""
def test_past_stop_threshold(self):
"""Tests for normal operating conditions."""
self.assertTrue(model_helpers.past_stop_threshold(0.54, 1))
self.assertTrue(model_helpers.past_stop_threshold(54, 100))
self.assertFalse(model_helpers.past_stop_threshold(0.54, 0.1))
self.assertFalse(model_helpers.past_stop_threshold(-0.54, -1.5))
self.assertTrue(model_helpers.past_stop_threshold(-0.54, 0))
self.assertTrue(model_helpers.past_stop_threshold(0, 0))
self.assertTrue(model_helpers.past_stop_threshold(0.54, 0.54))
def test_past_stop_threshold_none_false(self):
"""Tests that check None returns false."""
self.assertFalse(model_helpers.past_stop_threshold(None, -1.5))
self.assertFalse(model_helpers.past_stop_threshold(None, None))
self.assertFalse(model_helpers.past_stop_threshold(None, 1.5))
# Zero should be okay, though.
self.assertTrue(model_helpers.past_stop_threshold(0, 1.5))
def test_past_stop_threshold_not_number(self):
"""Tests for error conditions."""
with self.assertRaises(ValueError):
model_helpers.past_stop_threshold("str", 1)
with self.assertRaises(ValueError):
model_helpers.past_stop_threshold("str", tf.constant(5))
with self.assertRaises(ValueError):
model_helpers.past_stop_threshold("str", "another")
with self.assertRaises(ValueError):
model_helpers.past_stop_threshold(0, None)
with self.assertRaises(ValueError):
model_helpers.past_stop_threshold(0.7, "str")
with self.assertRaises(ValueError):
model_helpers.past_stop_threshold(tf.constant(4), None)
if __name__ == "__main__":
tf.test.main()
...@@ -26,6 +26,7 @@ import tensorflow as tf # pylint: disable=g-bad-import-order ...@@ -26,6 +26,7 @@ import tensorflow as tf # pylint: disable=g-bad-import-order
from official.utils.arg_parsers import parsers from official.utils.arg_parsers import parsers
from official.utils.logs import hooks_helper from official.utils.logs import hooks_helper
from official.utils.misc import model_helpers
_CSV_COLUMNS = [ _CSV_COLUMNS = [
'age', 'workclass', 'fnlwgt', 'education', 'education_num', 'age', 'workclass', 'fnlwgt', 'education', 'education_num',
...@@ -211,6 +212,10 @@ def main(argv): ...@@ -211,6 +212,10 @@ def main(argv):
for key in sorted(results): for key in sorted(results):
print('%s: %s' % (key, results[key])) print('%s: %s' % (key, results[key]))
if model_helpers.past_stop_threshold(
flags.stop_threshold, results['accuracy']):
break
class WideDeepArgParser(argparse.ArgumentParser): class WideDeepArgParser(argparse.ArgumentParser):
"""Argument parser for running the wide deep model.""" """Argument parser for running the wide deep model."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment