Unverified Commit dcca6b44 authored by Taylor Robie's avatar Taylor Robie Committed by GitHub
Browse files

Basic end to end test of resnet. (#3598)

This commit adds a basic end to end test for resnet cifar10 and imagenet models to check for syntax errors outside of the core neural net code. 
parent f028970f
...@@ -209,14 +209,7 @@ def cifar10_model_fn(features, labels, mode, params): ...@@ -209,14 +209,7 @@ def cifar10_model_fn(features, labels, mode, params):
multi_gpu=params['multi_gpu']) multi_gpu=params['multi_gpu'])
def main(unused_argv): def main(argv):
input_function = FLAGS.use_synthetic_data and get_synth_input_fn() or input_fn
resnet_run_loop.resnet_main(FLAGS, cifar10_model_fn, input_function)
if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO)
parser = resnet_run_loop.ResnetArgParser() parser = resnet_run_loop.ResnetArgParser()
# Set defaults that are reasonable for this model. # Set defaults that are reasonable for this model.
parser.set_defaults(data_dir='/tmp/cifar10_data', parser.set_defaults(data_dir='/tmp/cifar10_data',
...@@ -226,5 +219,12 @@ if __name__ == '__main__': ...@@ -226,5 +219,12 @@ if __name__ == '__main__':
epochs_per_eval=10, epochs_per_eval=10,
batch_size=128) batch_size=128)
FLAGS, unparsed = parser.parse_known_args() flags = parser.parse_args(args=argv[1:])
tf.app.run(argv=[sys.argv[0]] + unparsed)
input_function = flags.use_synthetic_data and get_synth_input_fn() or input_fn
resnet_run_loop.resnet_main(flags, cifar10_model_fn, input_function)
if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO)
tf.app.run(argv=sys.argv)
...@@ -23,6 +23,7 @@ import numpy as np ...@@ -23,6 +23,7 @@ import numpy as np
import tensorflow as tf import tensorflow as tf
from official.resnet import cifar10_main from official.resnet import cifar10_main
from official.utils.testing import integration
tf.logging.set_verbosity(tf.logging.ERROR) tf.logging.set_verbosity(tf.logging.ERROR)
...@@ -135,6 +136,12 @@ class BaseTest(tf.test.TestCase): ...@@ -135,6 +136,12 @@ class BaseTest(tf.test.TestCase):
self.assertAllEqual(output.shape, (batch_size, num_classes)) self.assertAllEqual(output.shape, (batch_size, num_classes))
def test_cifar10_end_to_end_synthetic_v1(self):
integration.run_synthetic(main=cifar10_main.main, extra_flags=['-v', '1'])
def test_cifar10_end_to_end_synthetic_v2(self):
integration.run_synthetic(main=cifar10_main.main, extra_flags=['-v', '2'])
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -281,15 +281,15 @@ def imagenet_model_fn(features, labels, mode, params): ...@@ -281,15 +281,15 @@ def imagenet_model_fn(features, labels, mode, params):
multi_gpu=params['multi_gpu']) multi_gpu=params['multi_gpu'])
def main(unused_argv): def main(argv):
input_function = FLAGS.use_synthetic_data and get_synth_input_fn() or input_fn parser = resnet_run_loop.ResnetArgParser(
resnet_run_loop.resnet_main(FLAGS, imagenet_model_fn, input_function) resnet_size_choices=[18, 34, 50, 101, 152, 200])
flags = parser.parse_args(args=argv[1:])
input_function = flags.use_synthetic_data and get_synth_input_fn() or input_fn
resnet_run_loop.resnet_main(flags, imagenet_model_fn, input_function)
if __name__ == '__main__': if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO) tf.logging.set_verbosity(tf.logging.INFO)
tf.app.run(argv=sys.argv)
parser = resnet_run_loop.ResnetArgParser(
resnet_size_choices=[18, 34, 50, 101, 152, 200])
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(argv=[sys.argv[0]] + unparsed)
...@@ -22,6 +22,7 @@ import unittest ...@@ -22,6 +22,7 @@ import unittest
import tensorflow as tf import tensorflow as tf
from official.resnet import imagenet_main from official.resnet import imagenet_main
from official.utils.testing import integration
tf.logging.set_verbosity(tf.logging.ERROR) tf.logging.set_verbosity(tf.logging.ERROR)
...@@ -242,6 +243,27 @@ class BaseTest(tf.test.TestCase): ...@@ -242,6 +243,27 @@ class BaseTest(tf.test.TestCase):
self.assertAllEqual(output.shape, (batch_size, num_classes)) self.assertAllEqual(output.shape, (batch_size, num_classes))
def test_imagenet_end_to_end_synthetic_v1(self):
integration.run_synthetic(main=imagenet_main.main, extra_flags=['-v', '1'])
def test_imagenet_end_to_end_synthetic_v2(self):
integration.run_synthetic(main=imagenet_main.main, extra_flags=['-v', '2'])
def test_imagenet_end_to_end_synthetic_v1_tiny(self):
integration.run_synthetic(main=imagenet_main.main,
extra_flags=['-v', '1', '-rs', '18'])
def test_imagenet_end_to_end_synthetic_v2_tiny(self):
integration.run_synthetic(main=imagenet_main.main,
extra_flags=['-v', '2', '-rs', '18'])
def test_imagenet_end_to_end_synthetic_v1_huge(self):
integration.run_synthetic(main=imagenet_main.main,
extra_flags=['-v', '1', '-rs', '200'])
def test_imagenet_end_to_end_synthetic_v2_huge(self):
integration.run_synthetic(main=imagenet_main.main,
extra_flags=['-v', '2', '-rs', '200'])
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
......
...@@ -436,4 +436,3 @@ class Model(object): ...@@ -436,4 +436,3 @@ class Model(object):
inputs = tf.identity(inputs, 'final_dense') inputs = tf.identity(inputs, 'final_dense')
return inputs return inputs
...@@ -349,7 +349,8 @@ def resnet_main(flags, model_function, input_function): ...@@ -349,7 +349,8 @@ def resnet_main(flags, model_function, input_function):
flags.epochs_per_eval, flags.num_parallel_calls, flags.epochs_per_eval, flags.num_parallel_calls,
flags.multi_gpu) flags.multi_gpu)
classifier.train(input_fn=input_fn_train, hooks=train_hooks) classifier.train(input_fn=input_fn_train, hooks=train_hooks,
max_steps=flags.max_train_steps)
print('Starting to evaluate.') print('Starting to evaluate.')
# Evaluate the model and print results # Evaluate the model and print results
...@@ -357,7 +358,14 @@ def resnet_main(flags, model_function, input_function): ...@@ -357,7 +358,14 @@ def resnet_main(flags, model_function, input_function):
return input_function(False, flags.data_dir, flags.batch_size, return input_function(False, flags.data_dir, flags.batch_size,
1, flags.num_parallel_calls, flags.multi_gpu) 1, flags.num_parallel_calls, flags.multi_gpu)
eval_results = classifier.evaluate(input_fn=input_fn_eval) # flags.max_train_steps is generally associated with testing and profiling.
# As a result it is frequently called with synthetic data, which will
# iterate forever. Passing steps=flags.max_train_steps allows the eval
# (which is generally unimportant in those circumstances) to terminate.
# Note that eval will run for max_train_steps each loop, regardless of the
# global_step count.
eval_results = classifier.evaluate(input_fn=input_fn_eval,
steps=flags.max_train_steps)
print(eval_results) print(eval_results)
...@@ -381,6 +389,6 @@ class ResnetArgParser(argparse.ArgumentParser): ...@@ -381,6 +389,6 @@ class ResnetArgParser(argparse.ArgumentParser):
self.add_argument( self.add_argument(
'--resnet_size', '-rs', type=int, default=50, '--resnet_size', '-rs', type=int, default=50,
choices=resnet_size_choices, choices=resnet_size_choices,
help='[default: %(default)s]The size of the ResNet model to use.', help='[default: %(default)s] The size of the ResNet model to use.',
metavar='<RS>' metavar='<RS>' if resnet_size_choices is None else None
) )
...@@ -54,6 +54,11 @@ Notes about add_argument(): ...@@ -54,6 +54,11 @@ Notes about add_argument():
""" """
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse import argparse
...@@ -142,7 +147,7 @@ class PerformanceParser(argparse.ArgumentParser): ...@@ -142,7 +147,7 @@ class PerformanceParser(argparse.ArgumentParser):
""" """
def __init__(self, add_help=False, num_parallel_calls=True, inter_op=True, def __init__(self, add_help=False, num_parallel_calls=True, inter_op=True,
intra_op=True, use_synthetic_data=True): intra_op=True, use_synthetic_data=True, max_train_steps=True):
super(PerformanceParser, self).__init__(add_help=add_help) super(PerformanceParser, self).__init__(add_help=add_help)
if num_parallel_calls: if num_parallel_calls:
...@@ -184,6 +189,17 @@ class PerformanceParser(argparse.ArgumentParser): ...@@ -184,6 +189,17 @@ class PerformanceParser(argparse.ArgumentParser):
"input processing steps, but will not learn anything." "input processing steps, but will not learn anything."
) )
if max_train_steps:
self.add_argument(
"--max_train_steps", "-mts", type=int, default=None,
help="[default: %(default)s] The model will stop training if the "
"global_step reaches this value. If not set, training will run"
"until the specified number of epochs have run as usual. It is"
"generally recommended to set --train_epochs=1 when using this"
"flag.",
metavar="<MTS>"
)
class ImageModelParser(argparse.ArgumentParser): class ImageModelParser(argparse.ArgumentParser):
"""Default parser for specification image specific behavior. """Default parser for specification image specific behavior.
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper code to run complete models from within python.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import shutil
import sys
import time
def run_synthetic(main, extra_flags=None):
"""Performs a minimal run of a model.
This function is intended to test for syntax errors throughout a model. A
very limited run is performed using synthetic data.
Args:
main: The primary function used to excercise a code path. Generally this
function is "<MODULE>.main(argv)".
extra_flags: Additional flags passed by the the caller of this function.
"""
extra_flags = [] if extra_flags is None else extra_flags
model_dir = "/tmp/model_test_{}".format(hash(time.time()))
if os.path.exists(model_dir):
shutil.rmtree(model_dir)
args = [sys.argv[0], "--model_dir", model_dir, "--train_epochs", "1",
"--epochs_per_eval", "1", "--use_synthetic_data",
"--max_train_steps", "1"] + extra_flags
try:
main(args)
finally:
if os.path.exists(model_dir):
shutil.rmtree(model_dir)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment