Commit d967bfae authored by guptapriya, committed by guptapriya

add strategy specific tests

parent 3aee5697
@@ -97,6 +97,12 @@ class TransformerTask(object):
         distribution_strategy=flags_obj.distribution_strategy,
         num_gpus=flags_core.get_num_gpus(flags_obj))
     print("Running transformer with num_gpus =", num_gpus)
+
+    if self.distribution_strategy:
+      print("For training, using distribution strategy: ",
+            self.distribution_strategy)
+    else:
+      print("Not using any distribution strategy.")
     self.params = params = misc.get_model_params(flags_obj.param_set, num_gpus)
     params["num_gpus"] = num_gpus
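Note: the logging added above reports whichever strategy object distribution_utils.get_distribution_strategy resolved from the --distribution_strategy flag; that helper's internals are not part of this diff. As a rough sketch only, the flag values exercised by the tests below could map onto public tf.distribute APIs as follows. The function name resolve_strategy and the exact device selection are illustrative assumptions, not the repo's actual code:

import tensorflow as tf

def resolve_strategy(distribution_strategy, num_gpus):
  """Sketch: map a flag string to a tf.distribute strategy, or None for "off"."""
  if distribution_strategy == "off":
    # Falsy result: TransformerTask takes the "Not using any" branch above.
    return None
  if distribution_strategy == "one_device":
    device = "/gpu:0" if num_gpus >= 1 else "/cpu:0"
    return tf.distribute.OneDeviceStrategy(device=device)
  if distribution_strategy == "mirrored":
    # Synchronous data parallelism with one replica per listed device.
    devices = ["/gpu:%d" % i for i in range(num_gpus)]
    return tf.distribute.MirroredStrategy(devices=devices)
  raise ValueError("Unknown distribution strategy: %s" % distribution_strategy)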
@@ -49,6 +49,8 @@ class TransformerTaskTest(tf.test.TestCase):
     FLAGS.train_steps = 2
     FLAGS.validation_steps = 1
     FLAGS.batch_size = 8
+    FLAGS.num_gpus = 1
+    FLAGS.distribution_strategy = "off"
     self.model_dir = FLAGS.model_dir
     self.temp_dir = temp_dir
     self.vocab_file = os.path.join(temp_dir, "vocab")

@@ -62,7 +64,23 @@ class TransformerTaskTest(tf.test.TestCase):
   def test_train(self):
     t = tm.TransformerTask(FLAGS)
     t.train()

+  def test_train_static_batch(self):
+    FLAGS.static_batch = True
+    t = tm.TransformerTask(FLAGS)
+    t.train()
+
+  def test_train_1_gpu_with_dist_strat(self):
+    FLAGS.distribution_strategy = "one_device"
+    t = tm.TransformerTask(FLAGS)
+    t.train()
+
+  def test_train_2_gpu(self):
+    FLAGS.distribution_strategy = "mirrored"
+    FLAGS.num_gpus = 2
+    t = tm.TransformerTask(FLAGS)
+    t.train()
+
   def _prepare_files_and_flags(self, *extra_flags):
     # Make log dir.
     if not os.path.exists(self.temp_dir):
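Note: test_train_2_gpu asks MirroredStrategy for two replicas, so two GPU devices must be visible. How the test environment provides them is not shown in this diff; one possible approach on a single-GPU host is to split the physical GPU into logical devices before TensorFlow initializes it. A sketch using TF's experimental virtual-device API, not necessarily what this test suite does:

import tensorflow as tf

# Sketch only (assumption): carve one physical GPU into two logical devices
# so MirroredStrategy can create two replicas. This must run before any op
# touches the GPU.
gpus = tf.config.experimental.list_physical_devices("GPU")
if len(gpus) == 1:
  tf.config.experimental.set_virtual_device_configuration(
      gpus[0],
      [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=1024),
       tf.config.experimental.VirtualDeviceConfiguration(memory_limit=1024)])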