"...resnet50_tensorflow.git" did not exist on "67d65c69a4367d221fcf7303d3ae65451da42e15"
Commit 824ff2d6 authored by Reed Wanderman-Milne's avatar Reed Wanderman-Milne Committed by A. Unique TensorFlower
Browse files

Do not expose --max_train_steps in models that do not use it.

Only the V1 resnet model uses --max_train_steps. This change stops exposing the flag in the keras_application_models, mnist, keras resnet, and CTL resnet models. Before this change, such models allowed the flag to be specified, but ignored it.

I also removed the "max_train" argument from the run_synthetic function, since it only had meaning for the V1 resnet model. Instead, the V1 resnet model now directly passes --max_train_steps=1 to run_synthetic.

PiperOrigin-RevId: 264269836
parent b974c3f9
...@@ -133,7 +133,7 @@ class BaseTest(tf.test.TestCase): ...@@ -133,7 +133,7 @@ class BaseTest(tf.test.TestCase):
"--eval_start", "12", "--eval_start", "12",
"--eval_count", "8", "--eval_count", "8",
], ],
synth=False, max_train=None) synth=False)
self.assertTrue(tf.gfile.Exists(os.path.join(model_dir, "checkpoint"))) self.assertTrue(tf.gfile.Exists(os.path.join(model_dir, "checkpoint")))
@unittest.skipIf(keras_utils.is_v2_0(), "TF 1.0 only test.") @unittest.skipIf(keras_utils.is_v2_0(), "TF 1.0 only test.")
...@@ -152,7 +152,7 @@ class BaseTest(tf.test.TestCase): ...@@ -152,7 +152,7 @@ class BaseTest(tf.test.TestCase):
"--eval_start", "12", "--eval_start", "12",
"--eval_count", "8", "--eval_count", "8",
], ],
synth=False, max_train=None) synth=False)
self.assertTrue(tf.gfile.Exists(os.path.join(model_dir, "checkpoint"))) self.assertTrue(tf.gfile.Exists(os.path.join(model_dir, "checkpoint")))
self.assertTrue(tf.gfile.Exists(os.path.join(export_dir))) self.assertTrue(tf.gfile.Exists(os.path.join(export_dir)))
......
...@@ -168,13 +168,15 @@ class BaseTest(tf.test.TestCase): ...@@ -168,13 +168,15 @@ class BaseTest(tf.test.TestCase):
def test_cifar10_end_to_end_synthetic_v1(self): def test_cifar10_end_to_end_synthetic_v1(self):
integration.run_synthetic( integration.run_synthetic(
main=cifar10_main.run_cifar, tmp_root=self.get_temp_dir(), main=cifar10_main.run_cifar, tmp_root=self.get_temp_dir(),
extra_flags=['-resnet_version', '1', '-batch_size', '4'] extra_flags=['-resnet_version', '1', '-batch_size', '4',
'--max_train_steps', '1']
) )
def test_cifar10_end_to_end_synthetic_v2(self): def test_cifar10_end_to_end_synthetic_v2(self):
integration.run_synthetic( integration.run_synthetic(
main=cifar10_main.run_cifar, tmp_root=self.get_temp_dir(), main=cifar10_main.run_cifar, tmp_root=self.get_temp_dir(),
extra_flags=['-resnet_version', '2', '-batch_size', '4'] extra_flags=['-resnet_version', '2', '-batch_size', '4',
'--max_train_steps', '1']
) )
......
...@@ -282,41 +282,43 @@ class BaseTest(tf.test.TestCase): ...@@ -282,41 +282,43 @@ class BaseTest(tf.test.TestCase):
def test_imagenet_end_to_end_synthetic_v1(self): def test_imagenet_end_to_end_synthetic_v1(self):
integration.run_synthetic( integration.run_synthetic(
main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(), main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-resnet_version', '1', '-batch_size', '4'] extra_flags=['-resnet_version', '1', '-batch_size', '4',
'--max_train_steps', '1']
) )
def test_imagenet_end_to_end_synthetic_v2(self): def test_imagenet_end_to_end_synthetic_v2(self):
integration.run_synthetic( integration.run_synthetic(
main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(), main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-resnet_version', '2', '-batch_size', '4'] extra_flags=['-resnet_version', '2', '-batch_size', '4',
'--max_train_steps', '1']
) )
def test_imagenet_end_to_end_synthetic_v1_tiny(self): def test_imagenet_end_to_end_synthetic_v1_tiny(self):
integration.run_synthetic( integration.run_synthetic(
main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(), main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-resnet_version', '1', '-batch_size', '4', extra_flags=['-resnet_version', '1', '-batch_size', '4',
'-resnet_size', '18'] '-resnet_size', '18', '--max_train_steps', '1']
) )
def test_imagenet_end_to_end_synthetic_v2_tiny(self): def test_imagenet_end_to_end_synthetic_v2_tiny(self):
integration.run_synthetic( integration.run_synthetic(
main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(), main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-resnet_version', '2', '-batch_size', '4', extra_flags=['-resnet_version', '2', '-batch_size', '4',
'-resnet_size', '18'] '-resnet_size', '18', '--max_train_steps', '1']
) )
def test_imagenet_end_to_end_synthetic_v1_huge(self): def test_imagenet_end_to_end_synthetic_v1_huge(self):
integration.run_synthetic( integration.run_synthetic(
main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(), main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-resnet_version', '1', '-batch_size', '4', extra_flags=['-resnet_version', '1', '-batch_size', '4',
'-resnet_size', '200'] '-resnet_size', '200', '--max_train_steps', '1']
) )
def test_imagenet_end_to_end_synthetic_v2_huge(self): def test_imagenet_end_to_end_synthetic_v2_huge(self):
integration.run_synthetic( integration.run_synthetic(
main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(), main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-resnet_version', '2', '-batch_size', '4', extra_flags=['-resnet_version', '2', '-batch_size', '4',
'-resnet_size', '200'] '-resnet_size', '200', '--max_train_steps', '1']
) )
......
...@@ -730,7 +730,8 @@ def define_resnet_flags(resnet_size_choices=None, dynamic_loss_scale=False, ...@@ -730,7 +730,8 @@ def define_resnet_flags(resnet_size_choices=None, dynamic_loss_scale=False,
dynamic_loss_scale=dynamic_loss_scale, dynamic_loss_scale=dynamic_loss_scale,
fp16_implementation=fp16_implementation, fp16_implementation=fp16_implementation,
loss_scale=True, loss_scale=True,
tf_data_experimental_slack=True) tf_data_experimental_slack=True,
max_train_steps=True)
flags_core.define_image() flags_core.define_image()
flags_core.define_benchmark() flags_core.define_benchmark()
flags_core.define_distribution() flags_core.define_distribution()
......
...@@ -139,7 +139,7 @@ class BaseTest(tf.test.TestCase): ...@@ -139,7 +139,7 @@ class BaseTest(tf.test.TestCase):
'--model_type', 'wide', '--model_type', 'wide',
'--download_if_missing=false' '--download_if_missing=false'
], ],
synth=False, max_train=None) synth=False)
@unittest.skipIf(keras_utils.is_v2_0(), 'TF 1.0 only test.') @unittest.skipIf(keras_utils.is_v2_0(), 'TF 1.0 only test.')
def test_end_to_end_deep(self): def test_end_to_end_deep(self):
...@@ -150,7 +150,7 @@ class BaseTest(tf.test.TestCase): ...@@ -150,7 +150,7 @@ class BaseTest(tf.test.TestCase):
'--model_type', 'deep', '--model_type', 'deep',
'--download_if_missing=false' '--download_if_missing=false'
], ],
synth=False, max_train=None) synth=False)
@unittest.skipIf(keras_utils.is_v2_0(), 'TF 1.0 only test.') @unittest.skipIf(keras_utils.is_v2_0(), 'TF 1.0 only test.')
def test_end_to_end_wide_deep(self): def test_end_to_end_wide_deep(self):
...@@ -161,7 +161,7 @@ class BaseTest(tf.test.TestCase): ...@@ -161,7 +161,7 @@ class BaseTest(tf.test.TestCase):
'--model_type', 'wide_deep', '--model_type', 'wide_deep',
'--download_if_missing=false' '--download_if_missing=false'
], ],
synth=False, max_train=None) synth=False)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -112,7 +112,7 @@ class BaseTest(tf.test.TestCase): ...@@ -112,7 +112,7 @@ class BaseTest(tf.test.TestCase):
"--train_epochs", "1", "--train_epochs", "1",
"--epochs_between_evals", "1" "--epochs_between_evals", "1"
], ],
synth=False, max_train=None) synth=False)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -195,20 +195,20 @@ class NcfTest(tf.test.TestCase): ...@@ -195,20 +195,20 @@ class NcfTest(tf.test.TestCase):
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100) @mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
def test_end_to_end_estimator(self): def test_end_to_end_estimator(self):
integration.run_synthetic( integration.run_synthetic(
ncf_estimator_main.main, tmp_root=self.get_temp_dir(), max_train=None, ncf_estimator_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS) extra_flags=self._BASE_END_TO_END_FLAGS)
@unittest.skipIf(keras_utils.is_v2_0(), "TODO(b/136018594)") @unittest.skipIf(keras_utils.is_v2_0(), "TODO(b/136018594)")
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100) @mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
def test_end_to_end_estimator_mlperf(self): def test_end_to_end_estimator_mlperf(self):
integration.run_synthetic( integration.run_synthetic(
ncf_estimator_main.main, tmp_root=self.get_temp_dir(), max_train=None, ncf_estimator_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS + ['-ml_perf', 'True']) extra_flags=self._BASE_END_TO_END_FLAGS + ['-ml_perf', 'True'])
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100) @mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
def test_end_to_end_keras_no_dist_strat(self): def test_end_to_end_keras_no_dist_strat(self):
integration.run_synthetic( integration.run_synthetic(
ncf_keras_main.main, tmp_root=self.get_temp_dir(), max_train=None, ncf_keras_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS + extra_flags=self._BASE_END_TO_END_FLAGS +
['-distribution_strategy', 'off']) ['-distribution_strategy', 'off'])
...@@ -216,7 +216,7 @@ class NcfTest(tf.test.TestCase): ...@@ -216,7 +216,7 @@ class NcfTest(tf.test.TestCase):
@unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.') @unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
def test_end_to_end_keras_dist_strat(self): def test_end_to_end_keras_dist_strat(self):
integration.run_synthetic( integration.run_synthetic(
ncf_keras_main.main, tmp_root=self.get_temp_dir(), max_train=None, ncf_keras_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '0']) extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '0'])
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100) @mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
...@@ -226,7 +226,7 @@ class NcfTest(tf.test.TestCase): ...@@ -226,7 +226,7 @@ class NcfTest(tf.test.TestCase):
['-num_gpus', '0'] + ['-num_gpus', '0'] +
['-keras_use_ctl', 'True']) ['-keras_use_ctl', 'True'])
integration.run_synthetic( integration.run_synthetic(
ncf_keras_main.main, tmp_root=self.get_temp_dir(), max_train=None, ncf_keras_main.main, tmp_root=self.get_temp_dir(),
extra_flags=flags) extra_flags=flags)
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100) @mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
...@@ -238,7 +238,7 @@ class NcfTest(tf.test.TestCase): ...@@ -238,7 +238,7 @@ class NcfTest(tf.test.TestCase):
format(1, context.num_gpus())) format(1, context.num_gpus()))
integration.run_synthetic( integration.run_synthetic(
ncf_keras_main.main, tmp_root=self.get_temp_dir(), max_train=None, ncf_keras_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '1']) extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '1'])
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100) @mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
...@@ -250,7 +250,7 @@ class NcfTest(tf.test.TestCase): ...@@ -250,7 +250,7 @@ class NcfTest(tf.test.TestCase):
format(2, context.num_gpus())) format(2, context.num_gpus()))
integration.run_synthetic( integration.run_synthetic(
ncf_keras_main.main, tmp_root=self.get_temp_dir(), max_train=None, ncf_keras_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '2']) extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '2'])
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -54,7 +54,7 @@ def get_loss_scale(flags_obj, default_for_fp16): ...@@ -54,7 +54,7 @@ def get_loss_scale(flags_obj, default_for_fp16):
def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True, def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
synthetic_data=True, max_train_steps=True, dtype=True, synthetic_data=True, max_train_steps=False, dtype=True,
all_reduce_alg=True, num_packs=True, all_reduce_alg=True, num_packs=True,
tf_gpu_thread_mode=False, tf_gpu_thread_mode=False,
datasets_num_private_threads=False, datasets_num_private_threads=False,
......
...@@ -29,7 +29,7 @@ from absl import flags ...@@ -29,7 +29,7 @@ from absl import flags
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
def run_synthetic(main, tmp_root, extra_flags=None, synth=True, max_train=1): def run_synthetic(main, tmp_root, extra_flags=None, synth=True):
"""Performs a minimal run of a model. """Performs a minimal run of a model.
This function is intended to test for syntax errors throughout a model. A This function is intended to test for syntax errors throughout a model. A
...@@ -41,7 +41,6 @@ def run_synthetic(main, tmp_root, extra_flags=None, synth=True, max_train=1): ...@@ -41,7 +41,6 @@ def run_synthetic(main, tmp_root, extra_flags=None, synth=True, max_train=1):
tmp_root: Root path for the temp directory created by the test class. tmp_root: Root path for the temp directory created by the test class.
extra_flags: Additional flags passed by the caller of this function. extra_flags: Additional flags passed by the caller of this function.
synth: Use synthetic data. synth: Use synthetic data.
max_train: Maximum number of allowed training steps.
""" """
extra_flags = [] if extra_flags is None else extra_flags extra_flags = [] if extra_flags is None else extra_flags
...@@ -54,9 +53,6 @@ def run_synthetic(main, tmp_root, extra_flags=None, synth=True, max_train=1): ...@@ -54,9 +53,6 @@ def run_synthetic(main, tmp_root, extra_flags=None, synth=True, max_train=1):
if synth: if synth:
args.append("--use_synthetic_data") args.append("--use_synthetic_data")
if max_train is not None:
args.extend(["--max_train_steps", str(max_train)])
try: try:
flags_core.parse_flags(argv=args) flags_core.parse_flags(argv=args)
main(flags.FLAGS) main(flags.FLAGS)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment