Unverified Commit dcca6b44 authored by Taylor Robie's avatar Taylor Robie Committed by GitHub
Browse files

Basic end to end test of resnet. (#3598)

This commit adds a basic end to end test for resnet cifar10 and imagenet models to check for syntax errors outside of the core neural net code. 
parent f028970f
......@@ -209,14 +209,7 @@ def cifar10_model_fn(features, labels, mode, params):
multi_gpu=params['multi_gpu'])
def main(unused_argv):
input_function = FLAGS.use_synthetic_data and get_synth_input_fn() or input_fn
resnet_run_loop.resnet_main(FLAGS, cifar10_model_fn, input_function)
if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO)
def main(argv):
parser = resnet_run_loop.ResnetArgParser()
# Set defaults that are reasonable for this model.
parser.set_defaults(data_dir='/tmp/cifar10_data',
......@@ -226,5 +219,12 @@ if __name__ == '__main__':
epochs_per_eval=10,
batch_size=128)
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(argv=[sys.argv[0]] + unparsed)
flags = parser.parse_args(args=argv[1:])
input_function = flags.use_synthetic_data and get_synth_input_fn() or input_fn
resnet_run_loop.resnet_main(flags, cifar10_model_fn, input_function)
if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO)
tf.app.run(argv=sys.argv)
......@@ -23,6 +23,7 @@ import numpy as np
import tensorflow as tf
from official.resnet import cifar10_main
from official.utils.testing import integration
tf.logging.set_verbosity(tf.logging.ERROR)
......@@ -135,6 +136,12 @@ class BaseTest(tf.test.TestCase):
self.assertAllEqual(output.shape, (batch_size, num_classes))
def test_cifar10_end_to_end_synthetic_v1(self):
integration.run_synthetic(main=cifar10_main.main, extra_flags=['-v', '1'])
def test_cifar10_end_to_end_synthetic_v2(self):
integration.run_synthetic(main=cifar10_main.main, extra_flags=['-v', '2'])
if __name__ == '__main__':
tf.test.main()
......@@ -281,15 +281,15 @@ def imagenet_model_fn(features, labels, mode, params):
multi_gpu=params['multi_gpu'])
def main(unused_argv):
input_function = FLAGS.use_synthetic_data and get_synth_input_fn() or input_fn
resnet_run_loop.resnet_main(FLAGS, imagenet_model_fn, input_function)
def main(argv):
parser = resnet_run_loop.ResnetArgParser(
resnet_size_choices=[18, 34, 50, 101, 152, 200])
flags = parser.parse_args(args=argv[1:])
input_function = flags.use_synthetic_data and get_synth_input_fn() or input_fn
resnet_run_loop.resnet_main(flags, imagenet_model_fn, input_function)
if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO)
parser = resnet_run_loop.ResnetArgParser(
resnet_size_choices=[18, 34, 50, 101, 152, 200])
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(argv=[sys.argv[0]] + unparsed)
tf.app.run(argv=sys.argv)
......@@ -22,6 +22,7 @@ import unittest
import tensorflow as tf
from official.resnet import imagenet_main
from official.utils.testing import integration
tf.logging.set_verbosity(tf.logging.ERROR)
......@@ -242,6 +243,27 @@ class BaseTest(tf.test.TestCase):
self.assertAllEqual(output.shape, (batch_size, num_classes))
def test_imagenet_end_to_end_synthetic_v1(self):
integration.run_synthetic(main=imagenet_main.main, extra_flags=['-v', '1'])
def test_imagenet_end_to_end_synthetic_v2(self):
integration.run_synthetic(main=imagenet_main.main, extra_flags=['-v', '2'])
def test_imagenet_end_to_end_synthetic_v1_tiny(self):
integration.run_synthetic(main=imagenet_main.main,
extra_flags=['-v', '1', '-rs', '18'])
def test_imagenet_end_to_end_synthetic_v2_tiny(self):
integration.run_synthetic(main=imagenet_main.main,
extra_flags=['-v', '2', '-rs', '18'])
def test_imagenet_end_to_end_synthetic_v1_huge(self):
integration.run_synthetic(main=imagenet_main.main,
extra_flags=['-v', '1', '-rs', '200'])
def test_imagenet_end_to_end_synthetic_v2_huge(self):
integration.run_synthetic(main=imagenet_main.main,
extra_flags=['-v', '2', '-rs', '200'])
if __name__ == '__main__':
tf.test.main()
......
......@@ -436,4 +436,3 @@ class Model(object):
inputs = tf.identity(inputs, 'final_dense')
return inputs
......@@ -349,7 +349,8 @@ def resnet_main(flags, model_function, input_function):
flags.epochs_per_eval, flags.num_parallel_calls,
flags.multi_gpu)
classifier.train(input_fn=input_fn_train, hooks=train_hooks)
classifier.train(input_fn=input_fn_train, hooks=train_hooks,
max_steps=flags.max_train_steps)
print('Starting to evaluate.')
# Evaluate the model and print results
......@@ -357,7 +358,14 @@ def resnet_main(flags, model_function, input_function):
return input_function(False, flags.data_dir, flags.batch_size,
1, flags.num_parallel_calls, flags.multi_gpu)
eval_results = classifier.evaluate(input_fn=input_fn_eval)
# flags.max_train_steps is generally associated with testing and profiling.
# As a result it is frequently called with synthetic data, which will
# iterate forever. Passing steps=flags.max_train_steps allows the eval
# (which is generally unimportant in those circumstances) to terminate.
# Note that eval will run for max_train_steps each loop, regardless of the
# global_step count.
eval_results = classifier.evaluate(input_fn=input_fn_eval,
steps=flags.max_train_steps)
print(eval_results)
......@@ -381,6 +389,6 @@ class ResnetArgParser(argparse.ArgumentParser):
self.add_argument(
'--resnet_size', '-rs', type=int, default=50,
choices=resnet_size_choices,
help='[default: %(default)s]The size of the ResNet model to use.',
metavar='<RS>'
help='[default: %(default)s] The size of the ResNet model to use.',
metavar='<RS>' if resnet_size_choices is None else None
)
......@@ -54,6 +54,11 @@ Notes about add_argument():
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
......@@ -142,7 +147,7 @@ class PerformanceParser(argparse.ArgumentParser):
"""
def __init__(self, add_help=False, num_parallel_calls=True, inter_op=True,
intra_op=True, use_synthetic_data=True):
intra_op=True, use_synthetic_data=True, max_train_steps=True):
super(PerformanceParser, self).__init__(add_help=add_help)
if num_parallel_calls:
......@@ -184,6 +189,17 @@ class PerformanceParser(argparse.ArgumentParser):
"input processing steps, but will not learn anything."
)
if max_train_steps:
self.add_argument(
"--max_train_steps", "-mts", type=int, default=None,
help="[default: %(default)s] The model will stop training if the "
"global_step reaches this value. If not set, training will run"
"until the specified number of epochs have run as usual. It is"
"generally recommended to set --train_epochs=1 when using this"
"flag.",
metavar="<MTS>"
)
class ImageModelParser(argparse.ArgumentParser):
"""Default parser for specification image specific behavior.
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper code to run complete models from within python.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import shutil
import sys
import time
def run_synthetic(main, extra_flags=None):
"""Performs a minimal run of a model.
This function is intended to test for syntax errors throughout a model. A
very limited run is performed using synthetic data.
Args:
main: The primary function used to excercise a code path. Generally this
function is "<MODULE>.main(argv)".
extra_flags: Additional flags passed by the the caller of this function.
"""
extra_flags = [] if extra_flags is None else extra_flags
model_dir = "/tmp/model_test_{}".format(hash(time.time()))
if os.path.exists(model_dir):
shutil.rmtree(model_dir)
args = [sys.argv[0], "--model_dir", model_dir, "--train_epochs", "1",
"--epochs_per_eval", "1", "--use_synthetic_data",
"--max_train_steps", "1"] + extra_flags
try:
main(args)
finally:
if os.path.exists(model_dir):
shutil.rmtree(model_dir)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment