Commit ff138931 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 312754139
parent b1eddf4f
......@@ -60,6 +60,7 @@ def define_flags():
"Initial checkpoint (usually from a pre-trained BERT model).")
flags.DEFINE_integer("train_steps", 100000, "Max train steps")
flags.DEFINE_integer("eval_steps", 32, "Number of eval steps per run.")
flags.DEFINE_integer("eval_timeout", 3000, "Timeout waiting for checkpoints.")
flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")
flags.DEFINE_integer("eval_batch_size", 4, "Total batch size for evaluation.")
flags.DEFINE_integer(
......@@ -203,7 +204,7 @@ def run():
if "train" in FLAGS.mode:
stats = train(params, strategy)
if "eval" in FLAGS.mode:
timeout = 0 if FLAGS.mode == "train_and_eval" else 3000
timeout = 0 if FLAGS.mode == "train_and_eval" else FLAGS.eval_timeout
# Uses padded decoding for TPU. Always uses cache.
padded_decode = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
params.override({
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment