Unverified Commit 9af0aad1 authored by Taylor Robie's avatar Taylor Robie Committed by GitHub
Browse files

fix treatment of main() and absl flags (#4159)

* fix treatment of main() and absl flags

* add mnist_eager, and delint

* add mnist_eager docstring
parent 5f9f6b84
...@@ -185,7 +185,13 @@ def validate_batch_size_for_multi_gpu(batch_size): ...@@ -185,7 +185,13 @@ def validate_batch_size_for_multi_gpu(batch_size):
raise ValueError(err) raise ValueError(err)
def main(flags_obj): def run_mnist(flags_obj):
"""Run MNIST training and eval loop.
Args:
flags_obj: An object containing parsed flag values.
"""
model_function = model_fn model_function = model_fn
if flags_obj.multi_gpu: if flags_obj.multi_gpu:
...@@ -251,6 +257,10 @@ def main(flags_obj): ...@@ -251,6 +257,10 @@ def main(flags_obj):
mnist_classifier.export_savedmodel(flags_obj.export_dir, input_fn) mnist_classifier.export_savedmodel(flags_obj.export_dir, input_fn)
def main(_):
run_mnist(flags.FLAGS)
if __name__ == '__main__': if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO) tf.logging.set_verbosity(tf.logging.INFO)
define_mnist_flags() define_mnist_flags()
......
...@@ -98,7 +98,12 @@ def test(model, dataset): ...@@ -98,7 +98,12 @@ def test(model, dataset):
tf.contrib.summary.scalar('accuracy', accuracy.result()) tf.contrib.summary.scalar('accuracy', accuracy.result())
def main(flags_obj): def run_mnist_eager(flags_obj):
"""Run MNIST training and eval loop in eager mode.
Args:
flags_obj: An object containing parsed flag values.
"""
tf.enable_eager_execution() tf.enable_eager_execution()
# Automatically determine device and data_format # Automatically determine device and data_format
...@@ -192,6 +197,11 @@ def define_mnist_eager_flags(): ...@@ -192,6 +197,11 @@ def define_mnist_eager_flags():
train_epochs=10, train_epochs=10,
) )
def main(_):
run_mnist_eager(flags.FLAGS)
if __name__ == '__main__': if __name__ == '__main__':
define_mnist_eager_flags() define_mnist_eager_flags()
absl_app.run(main=main) absl_app.run(main=main)
...@@ -238,7 +238,12 @@ def define_cifar_flags(): ...@@ -238,7 +238,12 @@ def define_cifar_flags():
batch_size=128) batch_size=128)
def main(flags_obj): def run_cifar(flags_obj):
"""Run ResNet CIFAR-10 training and eval loop.
Args:
flags_obj: An object containing parsed flag values.
"""
input_function = (flags_obj.use_synthetic_data and get_synth_input_fn() input_function = (flags_obj.use_synthetic_data and get_synth_input_fn()
or input_fn) or input_fn)
...@@ -247,6 +252,10 @@ def main(flags_obj): ...@@ -247,6 +252,10 @@ def main(flags_obj):
shape=[_HEIGHT, _WIDTH, _NUM_CHANNELS]) shape=[_HEIGHT, _WIDTH, _NUM_CHANNELS])
def main(_):
run_cifar(flags.FLAGS)
if __name__ == '__main__': if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO) tf.logging.set_verbosity(tf.logging.INFO)
define_cifar_flags() define_cifar_flags()
......
...@@ -177,13 +177,13 @@ class BaseTest(tf.test.TestCase): ...@@ -177,13 +177,13 @@ class BaseTest(tf.test.TestCase):
def test_cifar10_end_to_end_synthetic_v1(self): def test_cifar10_end_to_end_synthetic_v1(self):
integration.run_synthetic( integration.run_synthetic(
main=cifar10_main.main, tmp_root=self.get_temp_dir(), main=cifar10_main.run_cifar, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '1'] extra_flags=['-v', '1']
) )
def test_cifar10_end_to_end_synthetic_v2(self): def test_cifar10_end_to_end_synthetic_v2(self):
integration.run_synthetic( integration.run_synthetic(
main=cifar10_main.main, tmp_root=self.get_temp_dir(), main=cifar10_main.run_cifar, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '2'] extra_flags=['-v', '2']
) )
......
...@@ -313,7 +313,12 @@ def define_imagenet_flags(): ...@@ -313,7 +313,12 @@ def define_imagenet_flags():
flags_core.set_defaults(train_epochs=100) flags_core.set_defaults(train_epochs=100)
def main(flags_obj): def run_imagenet(flags_obj):
"""Run ResNet ImageNet training and eval loop.
Args:
flags_obj: An object containing parsed flag values.
"""
input_function = (flags_obj.use_synthetic_data and get_synth_input_fn() input_function = (flags_obj.use_synthetic_data and get_synth_input_fn()
or input_fn) or input_fn)
...@@ -322,6 +327,10 @@ def main(flags_obj): ...@@ -322,6 +327,10 @@ def main(flags_obj):
shape=[_DEFAULT_IMAGE_SIZE, _DEFAULT_IMAGE_SIZE, _NUM_CHANNELS]) shape=[_DEFAULT_IMAGE_SIZE, _DEFAULT_IMAGE_SIZE, _NUM_CHANNELS])
def main(_):
run_imagenet(flags.FLAGS)
if __name__ == '__main__': if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO) tf.logging.set_verbosity(tf.logging.INFO)
define_imagenet_flags() define_imagenet_flags()
......
...@@ -289,37 +289,37 @@ class BaseTest(tf.test.TestCase): ...@@ -289,37 +289,37 @@ class BaseTest(tf.test.TestCase):
def test_imagenet_end_to_end_synthetic_v1(self): def test_imagenet_end_to_end_synthetic_v1(self):
integration.run_synthetic( integration.run_synthetic(
main=imagenet_main.main, tmp_root=self.get_temp_dir(), main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '1'] extra_flags=['-v', '1']
) )
def test_imagenet_end_to_end_synthetic_v2(self): def test_imagenet_end_to_end_synthetic_v2(self):
integration.run_synthetic( integration.run_synthetic(
main=imagenet_main.main, tmp_root=self.get_temp_dir(), main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '2'] extra_flags=['-v', '2']
) )
def test_imagenet_end_to_end_synthetic_v1_tiny(self): def test_imagenet_end_to_end_synthetic_v1_tiny(self):
integration.run_synthetic( integration.run_synthetic(
main=imagenet_main.main, tmp_root=self.get_temp_dir(), main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '1', '-rs', '18'] extra_flags=['-v', '1', '-rs', '18']
) )
def test_imagenet_end_to_end_synthetic_v2_tiny(self): def test_imagenet_end_to_end_synthetic_v2_tiny(self):
integration.run_synthetic( integration.run_synthetic(
main=imagenet_main.main, tmp_root=self.get_temp_dir(), main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '2', '-rs', '18'] extra_flags=['-v', '2', '-rs', '18']
) )
def test_imagenet_end_to_end_synthetic_v1_huge(self): def test_imagenet_end_to_end_synthetic_v1_huge(self):
integration.run_synthetic( integration.run_synthetic(
main=imagenet_main.main, tmp_root=self.get_temp_dir(), main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '1', '-rs', '200'] extra_flags=['-v', '1', '-rs', '200']
) )
def test_imagenet_end_to_end_synthetic_v2_huge(self): def test_imagenet_end_to_end_synthetic_v2_huge(self):
integration.run_synthetic( integration.run_synthetic(
main=imagenet_main.main, tmp_root=self.get_temp_dir(), main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-v', '2', '-rs', '200'] extra_flags=['-v', '2', '-rs', '200']
) )
......
...@@ -217,7 +217,13 @@ def export_model(model, model_type, export_dir): ...@@ -217,7 +217,13 @@ def export_model(model, model_type, export_dir):
model.export_savedmodel(export_dir, example_input_fn) model.export_savedmodel(export_dir, example_input_fn)
def main(flags_obj): def run_wide_deep(flags_obj):
"""Run Wide-Deep training and eval loop.
Args:
flags_obj: An object containing parsed flag values.
"""
# Clean up the model directory if present # Clean up the model directory if present
shutil.rmtree(flags_obj.model_dir, ignore_errors=True) shutil.rmtree(flags_obj.model_dir, ignore_errors=True)
model = build_estimator(flags_obj.model_dir, flags_obj.model_type) model = build_estimator(flags_obj.model_dir, flags_obj.model_type)
...@@ -260,6 +266,10 @@ def main(flags_obj): ...@@ -260,6 +266,10 @@ def main(flags_obj):
export_model(model, flags_obj.model_type, flags_obj.export_dir) export_model(model, flags_obj.model_type, flags_obj.export_dir)
def main(_):
run_wide_deep(flags.FLAGS)
if __name__ == '__main__': if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO) tf.logging.set_verbosity(tf.logging.INFO)
define_wide_deep_flags() define_wide_deep_flags()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment