"git@developer.sourcefind.cn:OpenDAS/torch-scatter.git" did not exist on "c081ad0d554a3807aa81607ed8aa2ac139b49f39"
Commit 2a56bb7e authored by guptapriya's avatar guptapriya Committed by guptapriya
Browse files

fix distributed tests

parent d967bfae
...@@ -265,14 +265,15 @@ def _read_and_batch_from_files( ...@@ -265,14 +265,15 @@ def _read_and_batch_from_files(
def _generate_synthetic_data(params): def _generate_synthetic_data(params):
"""Create synthetic data based on the parameter batch size.""" """Create synthetic data based on the parameter batch size."""
batch = length = int(math.sqrt(params["batch_size"])) batch = length = int(math.sqrt(params["batch_size"]))
return model_helpers.generate_synthetic_data( dataset = model_helpers.generate_synthetic_data(
input_shape=tf.TensorShape([batch, length]), input_shape=tf.TensorShape([length]),
input_value=1, input_value=1,
input_dtype=tf.int32, input_dtype=tf.int32,
label_shape=tf.TensorShape([batch, length]), label_shape=tf.TensorShape([length]),
label_value=1, label_value=1,
label_dtype=tf.int32, label_dtype=tf.int32,
) )
return dataset.batch(batch)
def train_input_fn(params): def train_input_fn(params):
......
...@@ -43,7 +43,7 @@ class TransformerTaskTest(tf.test.TestCase): ...@@ -43,7 +43,7 @@ class TransformerTaskTest(tf.test.TestCase):
def setUp(self): def setUp(self):
temp_dir = self.get_temp_dir() temp_dir = self.get_temp_dir()
FLAGS.model_dir = os.path.join(temp_dir, FIXED_TIMESTAMP) FLAGS.model_dir = os.path.join(temp_dir, FIXED_TIMESTAMP)
FLAGS.param_set = param_set = "tiny" FLAGS.param_set = "tiny"
FLAGS.use_synthetic_data = True FLAGS.use_synthetic_data = True
FLAGS.steps_between_evals = 1 FLAGS.steps_between_evals = 1
FLAGS.train_steps = 2 FLAGS.train_steps = 2
...@@ -54,7 +54,7 @@ class TransformerTaskTest(tf.test.TestCase): ...@@ -54,7 +54,7 @@ class TransformerTaskTest(tf.test.TestCase):
self.model_dir = FLAGS.model_dir self.model_dir = FLAGS.model_dir
self.temp_dir = temp_dir self.temp_dir = temp_dir
self.vocab_file = os.path.join(temp_dir, "vocab") self.vocab_file = os.path.join(temp_dir, "vocab")
self.vocab_size = misc.get_model_params(param_set, 0)["vocab_size"] self.vocab_size = misc.get_model_params(FLAGS.param_set, 0)["vocab_size"]
self.bleu_source = os.path.join(temp_dir, "bleu_source") self.bleu_source = os.path.join(temp_dir, "bleu_source")
self.bleu_ref = os.path.join(temp_dir, "bleu_ref") self.bleu_ref = os.path.join(temp_dir, "bleu_ref")
...@@ -78,6 +78,7 @@ class TransformerTaskTest(tf.test.TestCase): ...@@ -78,6 +78,7 @@ class TransformerTaskTest(tf.test.TestCase):
def test_train_2_gpu(self): def test_train_2_gpu(self):
FLAGS.distribution_strategy = "mirrored" FLAGS.distribution_strategy = "mirrored"
FLAGS.num_gpus = 2 FLAGS.num_gpus = 2
FLAGS.param_set = "base"
t = tm.TransformerTask(FLAGS) t = tm.TransformerTask(FLAGS)
t.train() t.train()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment