Commit 2a56bb7e authored by guptapriya's avatar guptapriya Committed by guptapriya
Browse files

fix distributed tests

parent d967bfae
......@@ -265,14 +265,15 @@ def _read_and_batch_from_files(
def _generate_synthetic_data(params):
"""Create synthetic data based on the parameter batch size."""
batch = length = int(math.sqrt(params["batch_size"]))
return model_helpers.generate_synthetic_data(
input_shape=tf.TensorShape([batch, length]),
dataset = model_helpers.generate_synthetic_data(
input_shape=tf.TensorShape([length]),
input_value=1,
input_dtype=tf.int32,
label_shape=tf.TensorShape([batch, length]),
label_shape=tf.TensorShape([length]),
label_value=1,
label_dtype=tf.int32,
)
return dataset.batch(batch)
def train_input_fn(params):
......
......@@ -43,7 +43,7 @@ class TransformerTaskTest(tf.test.TestCase):
def setUp(self):
temp_dir = self.get_temp_dir()
FLAGS.model_dir = os.path.join(temp_dir, FIXED_TIMESTAMP)
FLAGS.param_set = param_set = "tiny"
FLAGS.param_set = "tiny"
FLAGS.use_synthetic_data = True
FLAGS.steps_between_evals = 1
FLAGS.train_steps = 2
......@@ -54,7 +54,7 @@ class TransformerTaskTest(tf.test.TestCase):
self.model_dir = FLAGS.model_dir
self.temp_dir = temp_dir
self.vocab_file = os.path.join(temp_dir, "vocab")
self.vocab_size = misc.get_model_params(param_set, 0)["vocab_size"]
self.vocab_size = misc.get_model_params(FLAGS.param_set, 0)["vocab_size"]
self.bleu_source = os.path.join(temp_dir, "bleu_source")
self.bleu_ref = os.path.join(temp_dir, "bleu_ref")
......@@ -78,6 +78,7 @@ class TransformerTaskTest(tf.test.TestCase):
def test_train_2_gpu(self):
FLAGS.distribution_strategy = "mirrored"
FLAGS.num_gpus = 2
FLAGS.param_set = "base"
t = tm.TransformerTask(FLAGS)
t.train()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment