"...git@developer.sourcefind.cn:OpenDAS/autoawq_kernels.git" did not exist on "b5592bd67a4a3836921d1a93ab90efd3b3e436f7"
Commit cd00b9a7 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Remove unnecessary flags

PiperOrigin-RevId: 276518206
parent 87d6459a
...@@ -180,12 +180,6 @@ def define_transformer_flags(): ...@@ -180,12 +180,6 @@ def define_transformer_flags():
default=False, default=False,
help=flags_core.help_wrap( help=flags_core.help_wrap(
'Whether the model runs with custom training loop.')) 'Whether the model runs with custom training loop.'))
flags.DEFINE_bool(
name='use_tpu_2vm_config',
default=False,
help=flags_core.help_wrap(
'Whether the model runs in 2VM mode, Headless server and unit test '
'all use 1VM config.'))
flags.DEFINE_integer( flags.DEFINE_integer(
name='decode_batch_size', name='decode_batch_size',
default=32, default=32,
......
...@@ -449,22 +449,14 @@ def main(_): ...@@ -449,22 +449,14 @@ def main(_):
with logger.benchmark_context(flags_obj): with logger.benchmark_context(flags_obj):
task = TransformerTask(flags_obj) task = TransformerTask(flags_obj)
def _run_task(task): if flags_obj.mode == "train":
if flags_obj.mode == "train": task.train()
task.train() elif flags_obj.mode == "predict":
elif flags_obj.mode == "predict": task.predict()
task.predict() elif flags_obj.mode == "eval":
elif flags_obj.mode == "eval": task.eval()
task.eval()
else:
raise ValueError("Invalid mode {}".format(flags_obj.mode))
if flags_obj.distribution_strategy != "tpu":
_run_task(task)
else: else:
primary_cpu_task = "/job:worker" if flags_obj.use_tpu_2vm_config else "" raise ValueError("Invalid mode {}".format(flags_obj.mode))
with tf.device(primary_cpu_task):
_run_task(task)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -28,7 +28,7 @@ from absl.testing import flagsaver ...@@ -28,7 +28,7 @@ from absl.testing import flagsaver
import tensorflow as tf import tensorflow as tf
from official.transformer.v2 import misc from official.transformer.v2 import misc
from official.transformer.v2 import transformer_main as tm from official.transformer.v2 import transformer_main
from official.utils.misc import keras_utils from official.utils.misc import keras_utils
from tensorflow.python.eager import context # pylint: disable=ungrouped-imports from tensorflow.python.eager import context # pylint: disable=ungrouped-imports
...@@ -84,7 +84,7 @@ class TransformerTaskTest(tf.test.TestCase): ...@@ -84,7 +84,7 @@ class TransformerTaskTest(tf.test.TestCase):
def test_train_no_dist_strat(self): def test_train_no_dist_strat(self):
if context.num_gpus() >= 2: if context.num_gpus() >= 2:
self.skipTest('No need to test 2+ GPUs without a distribution strategy.') self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
t = tm.TransformerTask(FLAGS) t = transformer_main.TransformerTask(FLAGS)
t.train() t.train()
def test_train_static_batch(self): def test_train_static_batch(self):
...@@ -96,20 +96,20 @@ class TransformerTaskTest(tf.test.TestCase): ...@@ -96,20 +96,20 @@ class TransformerTaskTest(tf.test.TestCase):
else: else:
FLAGS.num_gpus = 0 FLAGS.num_gpus = 0
FLAGS.static_batch = True FLAGS.static_batch = True
t = tm.TransformerTask(FLAGS) t = transformer_main.TransformerTask(FLAGS)
t.train() t.train()
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_train_1_gpu_with_dist_strat(self): def test_train_1_gpu_with_dist_strat(self):
FLAGS.distribution_strategy = 'one_device' FLAGS.distribution_strategy = 'one_device'
t = tm.TransformerTask(FLAGS) t = transformer_main.TransformerTask(FLAGS)
t.train() t.train()
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
def test_train_fp16(self): def test_train_fp16(self):
FLAGS.distribution_strategy = 'one_device' FLAGS.distribution_strategy = 'one_device'
FLAGS.dtype = 'fp16' FLAGS.dtype = 'fp16'
t = tm.TransformerTask(FLAGS) t = transformer_main.TransformerTask(FLAGS)
t.train() t.train()
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
...@@ -121,7 +121,7 @@ class TransformerTaskTest(tf.test.TestCase): ...@@ -121,7 +121,7 @@ class TransformerTaskTest(tf.test.TestCase):
FLAGS.distribution_strategy = 'mirrored' FLAGS.distribution_strategy = 'mirrored'
FLAGS.num_gpus = 2 FLAGS.num_gpus = 2
FLAGS.param_set = 'base' FLAGS.param_set = 'base'
t = tm.TransformerTask(FLAGS) t = transformer_main.TransformerTask(FLAGS)
t.train() t.train()
@unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU') @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
...@@ -134,7 +134,7 @@ class TransformerTaskTest(tf.test.TestCase): ...@@ -134,7 +134,7 @@ class TransformerTaskTest(tf.test.TestCase):
FLAGS.num_gpus = 2 FLAGS.num_gpus = 2
FLAGS.param_set = 'base' FLAGS.param_set = 'base'
FLAGS.dtype = 'fp16' FLAGS.dtype = 'fp16'
t = tm.TransformerTask(FLAGS) t = transformer_main.TransformerTask(FLAGS)
t.train() t.train()
def _prepare_files_and_flags(self, *extra_flags): def _prepare_files_and_flags(self, *extra_flags):
...@@ -167,14 +167,14 @@ class TransformerTaskTest(tf.test.TestCase): ...@@ -167,14 +167,14 @@ class TransformerTaskTest(tf.test.TestCase):
if context.num_gpus() >= 2: if context.num_gpus() >= 2:
self.skipTest('No need to test 2+ GPUs without a distribution strategy.') self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
self._prepare_files_and_flags() self._prepare_files_and_flags()
t = tm.TransformerTask(FLAGS) t = transformer_main.TransformerTask(FLAGS)
t.predict() t.predict()
def test_predict_fp16(self): def test_predict_fp16(self):
if context.num_gpus() >= 2: if context.num_gpus() >= 2:
self.skipTest('No need to test 2+ GPUs without a distribution strategy.') self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
self._prepare_files_and_flags('--dtype=fp16') self._prepare_files_and_flags('--dtype=fp16')
t = tm.TransformerTask(FLAGS) t = transformer_main.TransformerTask(FLAGS)
t.predict() t.predict()
def test_eval(self): def test_eval(self):
...@@ -183,7 +183,7 @@ class TransformerTaskTest(tf.test.TestCase): ...@@ -183,7 +183,7 @@ class TransformerTaskTest(tf.test.TestCase):
if 'test_xla' in sys.argv[0]: if 'test_xla' in sys.argv[0]:
self.skipTest('TODO(xla): Make this test faster under XLA.') self.skipTest('TODO(xla): Make this test faster under XLA.')
self._prepare_files_and_flags() self._prepare_files_and_flags()
t = tm.TransformerTask(FLAGS) t = transformer_main.TransformerTask(FLAGS)
t.eval() t.eval()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment