Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
310f70d5
Unverified
Commit
310f70d5
authored
Apr 10, 2018
by
Karmel Allison
Committed by
GitHub
Apr 10, 2018
Browse files
Adding stop threshold logic (#3863)
* Adding tests * Adding tests * Repackaging * Adding logging * Linting
parent
aad56e4c
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
167 additions
and
4 deletions
+167
-4
official/mnist/mnist.py
official/mnist/mnist.py
+5
-0
official/mnist/mnist_eager.py
official/mnist/mnist_eager.py
+1
-2
official/resnet/imagenet_test.py
official/resnet/imagenet_test.py
+1
-0
official/resnet/resnet_run_loop.py
official/resnet/resnet_run_loop.py
+5
-0
official/utils/arg_parsers/parsers.py
official/utils/arg_parsers/parsers.py
+26
-2
official/utils/misc/__init__.py
official/utils/misc/__init__.py
+0
-0
official/utils/misc/model_helpers.py
official/utils/misc/model_helpers.py
+55
-0
official/utils/misc/model_helpers_test.py
official/utils/misc/model_helpers_test.py
+69
-0
official/wide_deep/wide_deep.py
official/wide_deep/wide_deep.py
+5
-0
No files found.
official/mnist/mnist.py
View file @
310f70d5
...
@@ -25,6 +25,7 @@ import tensorflow as tf # pylint: disable=g-bad-import-order
...
@@ -25,6 +25,7 @@ import tensorflow as tf # pylint: disable=g-bad-import-order
from
official.mnist
import
dataset
from
official.mnist
import
dataset
from
official.utils.arg_parsers
import
parsers
from
official.utils.arg_parsers
import
parsers
from
official.utils.logs
import
hooks_helper
from
official.utils.logs
import
hooks_helper
from
official.utils.misc
import
model_helpers
LEARNING_RATE
=
1e-4
LEARNING_RATE
=
1e-4
...
@@ -231,6 +232,10 @@ def main(argv):
...
@@ -231,6 +232,10 @@ def main(argv):
eval_results
=
mnist_classifier
.
evaluate
(
input_fn
=
eval_input_fn
)
eval_results
=
mnist_classifier
.
evaluate
(
input_fn
=
eval_input_fn
)
print
(
'
\n
Evaluation results:
\n\t
%s
\n
'
%
eval_results
)
print
(
'
\n
Evaluation results:
\n\t
%s
\n
'
%
eval_results
)
if
model_helpers
.
past_stop_threshold
(
flags
.
stop_threshold
,
eval_results
[
'accuracy'
]):
break
# Export the model
# Export the model
if
flags
.
export_dir
is
not
None
:
if
flags
.
export_dir
is
not
None
:
image
=
tf
.
placeholder
(
tf
.
float32
,
[
None
,
28
,
28
])
image
=
tf
.
placeholder
(
tf
.
float32
,
[
None
,
28
,
28
])
...
...
official/mnist/mnist_eager.py
View file @
310f70d5
...
@@ -164,8 +164,7 @@ class MNISTEagerArgParser(argparse.ArgumentParser):
...
@@ -164,8 +164,7 @@ class MNISTEagerArgParser(argparse.ArgumentParser):
def
__init__
(
self
):
def
__init__
(
self
):
super
(
MNISTEagerArgParser
,
self
).
__init__
(
parents
=
[
super
(
MNISTEagerArgParser
,
self
).
__init__
(
parents
=
[
parsers
.
BaseParser
(
parsers
.
EagerParser
(),
epochs_between_evals
=
False
,
multi_gpu
=
False
,
hooks
=
False
),
parsers
.
ImageModelParser
()])
parsers
.
ImageModelParser
()])
self
.
add_argument
(
self
.
add_argument
(
...
...
official/resnet/imagenet_test.py
View file @
310f70d5
...
@@ -318,5 +318,6 @@ class BaseTest(tf.test.TestCase):
...
@@ -318,5 +318,6 @@ class BaseTest(tf.test.TestCase):
extra_flags
=
[
'-v'
,
'2'
,
'-rs'
,
'200'
]
extra_flags
=
[
'-v'
,
'2'
,
'-rs'
,
'200'
]
)
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
tf
.
test
.
main
()
official/resnet/resnet_run_loop.py
View file @
310f70d5
...
@@ -33,6 +33,7 @@ from official.utils.arg_parsers import parsers
...
@@ -33,6 +33,7 @@ from official.utils.arg_parsers import parsers
from
official.utils.export
import
export
from
official.utils.export
import
export
from
official.utils.logs
import
hooks_helper
from
official.utils.logs
import
hooks_helper
from
official.utils.logs
import
logger
from
official.utils.logs
import
logger
from
official.utils.misc
import
model_helpers
################################################################################
################################################################################
...
@@ -438,6 +439,10 @@ def resnet_main(flags, model_function, input_function, shape=None):
...
@@ -438,6 +439,10 @@ def resnet_main(flags, model_function, input_function, shape=None):
if
benchmark_logger
:
if
benchmark_logger
:
benchmark_logger
.
log_estimator_evaluation_result
(
eval_results
)
benchmark_logger
.
log_estimator_evaluation_result
(
eval_results
)
if
model_helpers
.
past_stop_threshold
(
flags
.
stop_threshold
,
eval_results
[
'accuracy'
]):
break
if
flags
.
export_dir
is
not
None
:
if
flags
.
export_dir
is
not
None
:
warn_on_multi_gpu_export
(
flags
.
multi_gpu
)
warn_on_multi_gpu_export
(
flags
.
multi_gpu
)
...
...
official/utils/arg_parsers/parsers.py
View file @
310f70d5
...
@@ -99,14 +99,17 @@ class BaseParser(argparse.ArgumentParser):
...
@@ -99,14 +99,17 @@ class BaseParser(argparse.ArgumentParser):
model_dir: Create a flag for specifying the model file directory.
model_dir: Create a flag for specifying the model file directory.
train_epochs: Create a flag to specify the number of training epochs.
train_epochs: Create a flag to specify the number of training epochs.
epochs_between_evals: Create a flag to specify the frequency of testing.
epochs_between_evals: Create a flag to specify the frequency of testing.
stop_threshold: Create a flag to specify a threshold accuracy or other
eval metric which should trigger the end of training.
batch_size: Create a flag to specify the batch size.
batch_size: Create a flag to specify the batch size.
multi_gpu: Create a flag to allow the use of all available GPUs.
multi_gpu: Create a flag to allow the use of all available GPUs.
hooks: Create a flag to specify hooks for logging.
hooks: Create a flag to specify hooks for logging.
"""
"""
def
__init__
(
self
,
add_help
=
False
,
data_dir
=
True
,
model_dir
=
True
,
def
__init__
(
self
,
add_help
=
False
,
data_dir
=
True
,
model_dir
=
True
,
train_epochs
=
True
,
epochs_between_evals
=
True
,
batch_size
=
True
,
train_epochs
=
True
,
epochs_between_evals
=
True
,
multi_gpu
=
True
,
hooks
=
True
):
stop_threshold
=
True
,
batch_size
=
True
,
multi_gpu
=
True
,
hooks
=
True
):
super
(
BaseParser
,
self
).
__init__
(
add_help
=
add_help
)
super
(
BaseParser
,
self
).
__init__
(
add_help
=
add_help
)
if
data_dir
:
if
data_dir
:
...
@@ -139,6 +142,15 @@ class BaseParser(argparse.ArgumentParser):
...
@@ -139,6 +142,15 @@ class BaseParser(argparse.ArgumentParser):
metavar
=
"<EBE>"
metavar
=
"<EBE>"
)
)
if
stop_threshold
:
self
.
add_argument
(
"--stop_threshold"
,
"-st"
,
type
=
float
,
default
=
None
,
help
=
"[default: %(default)s] If passed, training will stop at "
"the earlier of train_epochs and when the evaluation metric is "
"greater than or equal to stop_threshold."
,
metavar
=
"<ST>"
)
if
batch_size
:
if
batch_size
:
self
.
add_argument
(
self
.
add_argument
(
"--batch_size"
,
"-bs"
,
type
=
int
,
default
=
32
,
"--batch_size"
,
"-bs"
,
type
=
int
,
default
=
32
,
...
@@ -345,3 +357,15 @@ class BenchmarkParser(argparse.ArgumentParser):
...
@@ -345,3 +357,15 @@ class BenchmarkParser(argparse.ArgumentParser):
" benchmark metric information will be uploaded."
,
" benchmark metric information will be uploaded."
,
metavar
=
"<BMT>"
metavar
=
"<BMT>"
)
)
class
EagerParser
(
BaseParser
):
"""Remove options not relevant for Eager from the BaseParser."""
def
__init__
(
self
,
add_help
=
False
,
data_dir
=
True
,
model_dir
=
True
,
train_epochs
=
True
,
batch_size
=
True
):
super
(
EagerParser
,
self
).
__init__
(
add_help
=
add_help
,
data_dir
=
data_dir
,
model_dir
=
model_dir
,
train_epochs
=
train_epochs
,
epochs_between_evals
=
False
,
stop_threshold
=
False
,
batch_size
=
batch_size
,
multi_gpu
=
False
,
hooks
=
False
)
official/utils/misc/__init__.py
0 → 100644
View file @
310f70d5
official/utils/misc/model_helpers.py
0 → 100644
View file @
310f70d5
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Miscellaneous functions that can be called by models."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numbers
import
tensorflow
as
tf
def
past_stop_threshold
(
stop_threshold
,
eval_metric
):
"""Return a boolean representing whether a model should be stopped.
Args:
stop_threshold: float, the threshold above which a model should stop
training.
eval_metric: float, the current value of the relevant metric to check.
Returns:
True if training should stop, False otherwise.
Raises:
ValueError: if either stop_threshold or eval_metric is not a number
"""
if
stop_threshold
is
None
:
return
False
if
not
isinstance
(
stop_threshold
,
numbers
.
Number
):
raise
ValueError
(
"Threshold for checking stop conditions must be a number."
)
if
not
isinstance
(
eval_metric
,
numbers
.
Number
):
raise
ValueError
(
"Eval metric being checked against stop conditions "
"must be a number."
)
if
eval_metric
>=
stop_threshold
:
tf
.
logging
.
info
(
"Stop threshold of {} was passed with metric value {}."
.
format
(
stop_threshold
,
eval_metric
))
return
True
return
False
official/utils/misc/model_helpers_test.py
0 → 100644
View file @
310f70d5
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" Tests for Model Helper functions."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
tensorflow
as
tf
# pylint: disable=g-bad-import-order
from
official.utils.misc
import
model_helpers
class
PastStopThresholdTest
(
tf
.
test
.
TestCase
):
"""Tests for past_stop_threshold."""
def
test_past_stop_threshold
(
self
):
"""Tests for normal operating conditions."""
self
.
assertTrue
(
model_helpers
.
past_stop_threshold
(
0.54
,
1
))
self
.
assertTrue
(
model_helpers
.
past_stop_threshold
(
54
,
100
))
self
.
assertFalse
(
model_helpers
.
past_stop_threshold
(
0.54
,
0.1
))
self
.
assertFalse
(
model_helpers
.
past_stop_threshold
(
-
0.54
,
-
1.5
))
self
.
assertTrue
(
model_helpers
.
past_stop_threshold
(
-
0.54
,
0
))
self
.
assertTrue
(
model_helpers
.
past_stop_threshold
(
0
,
0
))
self
.
assertTrue
(
model_helpers
.
past_stop_threshold
(
0.54
,
0.54
))
def
test_past_stop_threshold_none_false
(
self
):
"""Tests that check None returns false."""
self
.
assertFalse
(
model_helpers
.
past_stop_threshold
(
None
,
-
1.5
))
self
.
assertFalse
(
model_helpers
.
past_stop_threshold
(
None
,
None
))
self
.
assertFalse
(
model_helpers
.
past_stop_threshold
(
None
,
1.5
))
# Zero should be okay, though.
self
.
assertTrue
(
model_helpers
.
past_stop_threshold
(
0
,
1.5
))
def
test_past_stop_threshold_not_number
(
self
):
"""Tests for error conditions."""
with
self
.
assertRaises
(
ValueError
):
model_helpers
.
past_stop_threshold
(
"str"
,
1
)
with
self
.
assertRaises
(
ValueError
):
model_helpers
.
past_stop_threshold
(
"str"
,
tf
.
constant
(
5
))
with
self
.
assertRaises
(
ValueError
):
model_helpers
.
past_stop_threshold
(
"str"
,
"another"
)
with
self
.
assertRaises
(
ValueError
):
model_helpers
.
past_stop_threshold
(
0
,
None
)
with
self
.
assertRaises
(
ValueError
):
model_helpers
.
past_stop_threshold
(
0.7
,
"str"
)
with
self
.
assertRaises
(
ValueError
):
model_helpers
.
past_stop_threshold
(
tf
.
constant
(
4
),
None
)
if
__name__
==
"__main__"
:
tf
.
test
.
main
()
official/wide_deep/wide_deep.py
View file @
310f70d5
...
@@ -26,6 +26,7 @@ import tensorflow as tf # pylint: disable=g-bad-import-order
...
@@ -26,6 +26,7 @@ import tensorflow as tf # pylint: disable=g-bad-import-order
from
official.utils.arg_parsers
import
parsers
from
official.utils.arg_parsers
import
parsers
from
official.utils.logs
import
hooks_helper
from
official.utils.logs
import
hooks_helper
from
official.utils.misc
import
model_helpers
_CSV_COLUMNS
=
[
_CSV_COLUMNS
=
[
'age'
,
'workclass'
,
'fnlwgt'
,
'education'
,
'education_num'
,
'age'
,
'workclass'
,
'fnlwgt'
,
'education'
,
'education_num'
,
...
@@ -211,6 +212,10 @@ def main(argv):
...
@@ -211,6 +212,10 @@ def main(argv):
for
key
in
sorted
(
results
):
for
key
in
sorted
(
results
):
print
(
'%s: %s'
%
(
key
,
results
[
key
]))
print
(
'%s: %s'
%
(
key
,
results
[
key
]))
if
model_helpers
.
past_stop_threshold
(
flags
.
stop_threshold
,
results
[
'accuracy'
]):
break
class
WideDeepArgParser
(
argparse
.
ArgumentParser
):
class
WideDeepArgParser
(
argparse
.
ArgumentParser
):
"""Argument parser for running the wide deep model."""
"""Argument parser for running the wide deep model."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment