Commit e6750c5d authored by Hongkun Yu, committed by A. Unique TensorFlower

Fix xlnet save_steps and steps_per_epoch conflicts

Remove train_data_size flag

PiperOrigin-RevId: 275545035
parent d59bcd47
@@ -59,7 +59,7 @@ def _float_metric_value(metric):
   return metric.result().numpy().astype(float)
 
 
-def _steps_to_run(current_step, steps_per_epoch, steps_per_loop):
+def steps_to_run(current_step, steps_per_epoch, steps_per_loop):
   """Calculates steps to run on device."""
   if steps_per_loop <= 0:
     raise ValueError('steps_per_loop should be positive integer.')
@@ -353,7 +353,7 @@ def run_customized_training_loop(
       _run_callbacks_on_batch_begin(current_step)
       # Runs several steps in the host while loop.
-      steps = _steps_to_run(current_step, steps_per_epoch, steps_per_loop)
+      steps = steps_to_run(current_step, steps_per_epoch, steps_per_loop)
 
       if steps == 1:
         # TODO(zongweiz): merge with train_steps once tf.while_loop
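For orientation, here is a hedged sketch of what the now-shared helper computes, reconstructed from the `_steps_to_run` body removed from the xlnet copy further down in this diff. The second parameter keeps the name steps_per_epoch in model_training_utils, but xlnet now passes save_steps for it, so checkpoint boundaries rather than epoch boundaries define where each host loop stops:

def steps_to_run(current_step, steps_per_epoch, steps_per_loop):
  """Sketch: steps for the next host loop, truncated so the loop always
  pauses exactly on a multiple of steps_per_epoch (the boundary)."""
  if steps_per_loop <= 0:
    raise ValueError('steps_per_loop should be positive integer.')
  if steps_per_loop == 1:
    return steps_per_loop
  remainder_in_epoch = current_step % steps_per_epoch
  if remainder_in_epoch != 0:
    return min(steps_per_epoch - remainder_in_epoch, steps_per_loop)
  return steps_per_loop

# e.g. steps_to_run(80, 100, 40) == 20: the next loop stops at step 100,
# exactly where the checkpoint/evaluation hook needs to fire.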
@@ -127,8 +127,6 @@ flags.DEFINE_float(
 flags.DEFINE_float(
     "init_range", default=0.1, help="Initialization std when init is uniform.")
-flags.DEFINE_integer(
-    "train_data_size", default=130738, help="Number of training data samples.")
 flags.DEFINE_integer(
     "test_data_size", default=12048, help="Number of test data samples.")
 flags.DEFINE_string(
@@ -147,7 +147,6 @@ def main(unused_argv):
       strategy, False, FLAGS.test_tfrecord_path)
   total_training_steps = FLAGS.train_steps
-  steps_per_epoch = int(FLAGS.train_data_size / FLAGS.train_batch_size)
   steps_per_loop = FLAGS.iterations
   eval_steps = int(FLAGS.test_data_size / FLAGS.test_batch_size)
   eval_fn = functools.partial(run_evaluation, strategy, test_input_fn,
@@ -181,7 +180,6 @@ def main(unused_argv):
       init_checkpoint=FLAGS.init_checkpoint,
       init_from_transformerxl=FLAGS.init_from_transformerxl,
       total_training_steps=total_training_steps,
-      steps_per_epoch=steps_per_epoch,
       steps_per_loop=steps_per_loop,
       optimizer=optimizer,
       learning_rate_fn=learning_rate_fn,
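After the removal, the step bookkeeping in each xlnet runner reduces to the sketch below (a standalone approximation; flag defaults other than test_data_size are placeholders, not values from this repo):

from absl import flags

flags.DEFINE_integer("train_steps", default=10000, help="Total training steps.")
flags.DEFINE_integer("iterations", default=500, help="Steps per host loop.")
flags.DEFINE_integer("test_data_size", default=12048, help="Number of test data samples.")
flags.DEFINE_integer("test_batch_size", default=16, help="Eval batch size.")
FLAGS = flags.FLAGS
FLAGS(["demo"])  # parse defaults so this sketch runs standalone

total_training_steps = FLAGS.train_steps  # taken directly from the flag
steps_per_loop = FLAGS.iterations
eval_steps = int(FLAGS.test_data_size / FLAGS.test_batch_size)  # 753 here
# steps_per_epoch = int(train_data_size / train_batch_size) is gone entirely.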
@@ -105,7 +105,6 @@ def main(unused_argv):
       num_hosts)
   total_training_steps = FLAGS.train_steps
-  steps_per_epoch = int(FLAGS.train_data_size / FLAGS.train_batch_size)
   steps_per_loop = FLAGS.iterations
 
   optimizer, learning_rate_fn = optimization.create_optimizer(
@@ -139,7 +139,6 @@ def main(unused_argv):
       init_checkpoint=FLAGS.init_checkpoint,
       init_from_transformerxl=FLAGS.init_from_transformerxl,
       total_training_steps=total_training_steps,
-      steps_per_epoch=steps_per_epoch,
       steps_per_loop=steps_per_loop,
       optimizer=optimizer,
       learning_rate_fn=learning_rate_fn,
@@ -239,7 +239,6 @@ def main(unused_argv):
       FLAGS.test_tfrecord_path)
   total_training_steps = FLAGS.train_steps
-  steps_per_epoch = int(FLAGS.train_data_size / FLAGS.train_batch_size)
   steps_per_loop = FLAGS.iterations
   eval_steps = int(FLAGS.test_data_size / FLAGS.test_batch_size)
@@ -286,7 +285,6 @@ def main(unused_argv):
       init_checkpoint=FLAGS.init_checkpoint,
       init_from_transformerxl=FLAGS.init_from_transformerxl,
       total_training_steps=total_training_steps,
-      steps_per_epoch=steps_per_epoch,
       steps_per_loop=steps_per_loop,
       optimizer=optimizer,
       learning_rate_fn=learning_rate_fn,
@@ -49,26 +49,12 @@ def _float_metric_value(metric):
   return metric.result().numpy().astype(float)
 
 
-def _steps_to_run(current_step, steps_per_epoch, steps_per_loop):
-  """Calculates steps to run on device."""
-  if steps_per_loop <= 0:
-    raise ValueError("steps_per_loop should be positive integer.")
-  if steps_per_loop == 1:
-    return steps_per_loop
-  remainder_in_epoch = current_step % steps_per_epoch
-  if remainder_in_epoch != 0:
-    return min(steps_per_epoch - remainder_in_epoch, steps_per_loop)
-  else:
-    return steps_per_loop
-
-
 def train(
     strategy: tf.distribute.Strategy,
     model_fn: Callable,
     input_meta_data: Dict,
     train_input_fn: Callable,
     total_training_steps: int,
-    steps_per_epoch: int,
     steps_per_loop: int,
     optimizer: tf.keras.optimizers.Optimizer,
     learning_rate_fn: tf.keras.optimizers.schedules.LearningRateSchedule,
@@ -90,9 +76,6 @@ def train(
       `n_layer`, `batch_size_per_core` and `d_model`.
     train_input_fn: Function returns a tf.data.Dataset used for training.
     total_training_steps: Number of steps to train in total.
-    steps_per_epoch: Number of steps to run per epoch. At the end of each
-      epoch, model checkpoint will be saved and evaluation will be conducted
-      if evaluation dataset is provided.
     steps_per_loop: Number of steps per graph-mode loop. In order to reduce
       communication in eager context, training logs are printed every
       steps_per_loop.
@@ -111,7 +94,8 @@ def train(
       `model_fn`.
     model_dir: The directory of model (checkpoints, summaries).
     save_steps: The frequency to save checkpoints. Every save_steps, we save a
-      model checkpoint.
+      model checkpoint. Model checkpoint will be saved and evaluation will be
+      conducted if evaluation dataset is provided.
     run_eagerly: Whether to run training eagerly.
 
   Returns:
@@ -120,12 +104,12 @@ def train(
     TypeError: if model directory is not specified.
   """
   required_arguments = [
-      train_input_fn, total_training_steps, steps_per_epoch, steps_per_loop,
-      optimizer, learning_rate_fn
+      train_input_fn, total_training_steps, steps_per_loop, optimizer,
+      learning_rate_fn, save_steps
   ]
   if [arg for arg in required_arguments if arg is None]:
     raise ValueError("`train_input_fn`, `total_training_steps`, "
-                     "`steps_per_epoch`, `steps_per_loop`, `optimizer` and "
+                     "`steps_per_loop`, `optimizer`, `save_steps` and "
                      "`learning_rate_fn` are required parameters.")
   if not model_dir:
     raise TypeError("Model directory must be specified.")
@@ -159,7 +143,7 @@ def train(
         transformer_xl=model.transformerxl_model)
   else:
     checkpoint = tf.train.Checkpoint(model=model)
-  checkpoint.restore(init_checkpoint)
+  checkpoint.restore(init_checkpoint).assert_existing_objects_matched()
 
   model.optimizer = optimizer
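Chaining assert_existing_objects_matched() onto restore() is standard TF2 checkpoint API; it turns a silently partial restore into an immediate error. A minimal self-contained sketch (hypothetical model and path):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])
checkpoint = tf.train.Checkpoint(model=model)
prefix = checkpoint.save("/tmp/xlnet_demo/ckpt")  # hypothetical location

# Restoring into a matching, already-built object graph passes the check;
# a checkpoint missing variables for built objects raises AssertionError
# instead of leaving them at their random initial values.
status = checkpoint.restore(prefix)
status.assert_existing_objects_matched()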
@@ -274,7 +258,8 @@ def train(
     if train_metric:
       train_metric.reset_states()
 
-    steps = _steps_to_run(current_step, steps_per_epoch, steps_per_loop)
+    steps = model_training_utils.steps_to_run(current_step, save_steps,
+                                              steps_per_loop)
     train_steps(train_iterator, tf.convert_to_tensor(steps, dtype=tf.int32))
     current_step += steps
     train_loss = _float_metric_value(train_loss_metric)
@@ -299,13 +284,11 @@ def train(
             _float_metric_value(train_metric),
             step=current_step)
         train_summary_writer.flush()
-    if model_dir:
-      if (save_steps is None) or (save_steps and
-                                  current_step % save_steps == 0):
-        _save_checkpoint(checkpoint, model_dir,
-                         checkpoint_name.format(step=current_step))
+    if model_dir and current_step % save_steps == 0:
+      _save_checkpoint(checkpoint, model_dir,
+                       checkpoint_name.format(step=current_step))
 
-    if test_input_fn and current_step % steps_per_epoch == 0:
+    if test_input_fn and current_step % save_steps == 0:
       logging.info("Running evaluation after step: %s.", current_step)
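Taken together, the loop now truncates on save_steps and both the checkpoint and evaluation hooks share that single boundary, so they can no longer disagree the way save_steps and steps_per_epoch could. A self-contained trace with stand-in numbers (nothing here is from the repo) shows the cadence:

current_step, save_steps, steps_per_loop = 0, 100, 40  # hypothetical numbers
while current_step < 200:
  remainder = current_step % save_steps
  steps = min(save_steps - remainder, steps_per_loop) if remainder else steps_per_loop
  current_step += steps  # stand-in for train_steps(...)
  if current_step % save_steps == 0:
    print("step %d: save checkpoint and run evaluation" % current_step)
# Host loops run 40, 40, 20, 40, 40, 20 steps; both hooks fire exactly
# at steps 100 and 200.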