".github/vscode:/vscode.git/clone" did not exist on "1df3b64d007b4554dc8b7e481ca500421364c310"
Commit 824ff2d6 authored by Reed Wanderman-Milne's avatar Reed Wanderman-Milne Committed by A. Unique TensorFlower
Browse files

Do not expose --max_train_steps in models that do not use it.

Only the V1 resnet model uses --max_train_steps. This unexposes the flag in the keras_application_models, mnist, keras resnet, CTL resnet Models. Before this change, such models allowed the flag to be specified, but ignored it.

I also removed the "max_train" argument from the run_synthetic function, since this only had any meaning for the V1 resnet model. Instead, the V1 resnet model now directly passes --max_train_steps=1 to run_synthetic.

PiperOrigin-RevId: 264269836
parent b974c3f9
......@@ -133,7 +133,7 @@ class BaseTest(tf.test.TestCase):
"--eval_start", "12",
"--eval_count", "8",
],
synth=False, max_train=None)
synth=False)
self.assertTrue(tf.gfile.Exists(os.path.join(model_dir, "checkpoint")))
@unittest.skipIf(keras_utils.is_v2_0(), "TF 1.0 only test.")
......@@ -152,7 +152,7 @@ class BaseTest(tf.test.TestCase):
"--eval_start", "12",
"--eval_count", "8",
],
synth=False, max_train=None)
synth=False)
self.assertTrue(tf.gfile.Exists(os.path.join(model_dir, "checkpoint")))
self.assertTrue(tf.gfile.Exists(os.path.join(export_dir)))
......
......@@ -168,13 +168,15 @@ class BaseTest(tf.test.TestCase):
def test_cifar10_end_to_end_synthetic_v1(self):
integration.run_synthetic(
main=cifar10_main.run_cifar, tmp_root=self.get_temp_dir(),
extra_flags=['-resnet_version', '1', '-batch_size', '4']
extra_flags=['-resnet_version', '1', '-batch_size', '4',
'--max_train_steps', '1']
)
def test_cifar10_end_to_end_synthetic_v2(self):
integration.run_synthetic(
main=cifar10_main.run_cifar, tmp_root=self.get_temp_dir(),
extra_flags=['-resnet_version', '2', '-batch_size', '4']
extra_flags=['-resnet_version', '2', '-batch_size', '4',
'--max_train_steps', '1']
)
......
......@@ -282,41 +282,43 @@ class BaseTest(tf.test.TestCase):
def test_imagenet_end_to_end_synthetic_v1(self):
integration.run_synthetic(
main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-resnet_version', '1', '-batch_size', '4']
extra_flags=['-resnet_version', '1', '-batch_size', '4',
'--max_train_steps', '1']
)
def test_imagenet_end_to_end_synthetic_v2(self):
integration.run_synthetic(
main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-resnet_version', '2', '-batch_size', '4']
extra_flags=['-resnet_version', '2', '-batch_size', '4',
'--max_train_steps', '1']
)
def test_imagenet_end_to_end_synthetic_v1_tiny(self):
integration.run_synthetic(
main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-resnet_version', '1', '-batch_size', '4',
'-resnet_size', '18']
'-resnet_size', '18', '--max_train_steps', '1']
)
def test_imagenet_end_to_end_synthetic_v2_tiny(self):
integration.run_synthetic(
main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-resnet_version', '2', '-batch_size', '4',
'-resnet_size', '18']
'-resnet_size', '18', '--max_train_steps', '1']
)
def test_imagenet_end_to_end_synthetic_v1_huge(self):
integration.run_synthetic(
main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-resnet_version', '1', '-batch_size', '4',
'-resnet_size', '200']
'-resnet_size', '200', '--max_train_steps', '1']
)
def test_imagenet_end_to_end_synthetic_v2_huge(self):
integration.run_synthetic(
main=imagenet_main.run_imagenet, tmp_root=self.get_temp_dir(),
extra_flags=['-resnet_version', '2', '-batch_size', '4',
'-resnet_size', '200']
'-resnet_size', '200', '--max_train_steps', '1']
)
......
......@@ -730,7 +730,8 @@ def define_resnet_flags(resnet_size_choices=None, dynamic_loss_scale=False,
dynamic_loss_scale=dynamic_loss_scale,
fp16_implementation=fp16_implementation,
loss_scale=True,
tf_data_experimental_slack=True)
tf_data_experimental_slack=True,
max_train_steps=True)
flags_core.define_image()
flags_core.define_benchmark()
flags_core.define_distribution()
......
......@@ -139,7 +139,7 @@ class BaseTest(tf.test.TestCase):
'--model_type', 'wide',
'--download_if_missing=false'
],
synth=False, max_train=None)
synth=False)
@unittest.skipIf(keras_utils.is_v2_0(), 'TF 1.0 only test.')
def test_end_to_end_deep(self):
......@@ -150,7 +150,7 @@ class BaseTest(tf.test.TestCase):
'--model_type', 'deep',
'--download_if_missing=false'
],
synth=False, max_train=None)
synth=False)
@unittest.skipIf(keras_utils.is_v2_0(), 'TF 1.0 only test.')
def test_end_to_end_wide_deep(self):
......@@ -161,7 +161,7 @@ class BaseTest(tf.test.TestCase):
'--model_type', 'wide_deep',
'--download_if_missing=false'
],
synth=False, max_train=None)
synth=False)
if __name__ == '__main__':
......
......@@ -112,7 +112,7 @@ class BaseTest(tf.test.TestCase):
"--train_epochs", "1",
"--epochs_between_evals", "1"
],
synth=False, max_train=None)
synth=False)
if __name__ == "__main__":
......
......@@ -195,20 +195,20 @@ class NcfTest(tf.test.TestCase):
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
def test_end_to_end_estimator(self):
integration.run_synthetic(
ncf_estimator_main.main, tmp_root=self.get_temp_dir(), max_train=None,
ncf_estimator_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS)
@unittest.skipIf(keras_utils.is_v2_0(), "TODO(b/136018594)")
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
def test_end_to_end_estimator_mlperf(self):
integration.run_synthetic(
ncf_estimator_main.main, tmp_root=self.get_temp_dir(), max_train=None,
ncf_estimator_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS + ['-ml_perf', 'True'])
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
def test_end_to_end_keras_no_dist_strat(self):
integration.run_synthetic(
ncf_keras_main.main, tmp_root=self.get_temp_dir(), max_train=None,
ncf_keras_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS +
['-distribution_strategy', 'off'])
......@@ -216,7 +216,7 @@ class NcfTest(tf.test.TestCase):
@unittest.skipUnless(keras_utils.is_v2_0(), 'TF 2.0 only test.')
def test_end_to_end_keras_dist_strat(self):
integration.run_synthetic(
ncf_keras_main.main, tmp_root=self.get_temp_dir(), max_train=None,
ncf_keras_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '0'])
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
......@@ -226,7 +226,7 @@ class NcfTest(tf.test.TestCase):
['-num_gpus', '0'] +
['-keras_use_ctl', 'True'])
integration.run_synthetic(
ncf_keras_main.main, tmp_root=self.get_temp_dir(), max_train=None,
ncf_keras_main.main, tmp_root=self.get_temp_dir(),
extra_flags=flags)
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
......@@ -238,7 +238,7 @@ class NcfTest(tf.test.TestCase):
format(1, context.num_gpus()))
integration.run_synthetic(
ncf_keras_main.main, tmp_root=self.get_temp_dir(), max_train=None,
ncf_keras_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '1'])
@mock.patch.object(rconst, "SYNTHETIC_BATCHES_PER_EPOCH", 100)
......@@ -250,7 +250,7 @@ class NcfTest(tf.test.TestCase):
format(2, context.num_gpus()))
integration.run_synthetic(
ncf_keras_main.main, tmp_root=self.get_temp_dir(), max_train=None,
ncf_keras_main.main, tmp_root=self.get_temp_dir(),
extra_flags=self._BASE_END_TO_END_FLAGS + ['-num_gpus', '2'])
if __name__ == "__main__":
......
......@@ -54,7 +54,7 @@ def get_loss_scale(flags_obj, default_for_fp16):
def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True,
synthetic_data=True, max_train_steps=True, dtype=True,
synthetic_data=True, max_train_steps=False, dtype=True,
all_reduce_alg=True, num_packs=True,
tf_gpu_thread_mode=False,
datasets_num_private_threads=False,
......
......@@ -29,7 +29,7 @@ from absl import flags
from official.utils.flags import core as flags_core
def run_synthetic(main, tmp_root, extra_flags=None, synth=True, max_train=1):
def run_synthetic(main, tmp_root, extra_flags=None, synth=True):
"""Performs a minimal run of a model.
This function is intended to test for syntax errors throughout a model. A
......@@ -41,7 +41,6 @@ def run_synthetic(main, tmp_root, extra_flags=None, synth=True, max_train=1):
tmp_root: Root path for the temp directory created by the test class.
extra_flags: Additional flags passed by the caller of this function.
synth: Use synthetic data.
max_train: Maximum number of allowed training steps.
"""
extra_flags = [] if extra_flags is None else extra_flags
......@@ -54,9 +53,6 @@ def run_synthetic(main, tmp_root, extra_flags=None, synth=True, max_train=1):
if synth:
args.append("--use_synthetic_data")
if max_train is not None:
args.extend(["--max_train_steps", str(max_train)])
try:
flags_core.parse_flags(argv=args)
main(flags.FLAGS)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment