Commit e6750c5d authored by Hongkun Yu, committed by A. Unique TensorFlower

Fix xlnet save_steps and steps_per_epoch conflicts

Remove train_data_size flag

PiperOrigin-RevId: 275545035
parent d59bcd47
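
Reviewer note: the xlnet trainer used to align its inner device loop to epoch boundaries (derived from the now-removed train_data_size flag) while checkpoints were written every save_steps, so the two cadences could disagree. After this commit, save_steps alone drives loop alignment, checkpointing, and evaluation, via the shared steps_to_run helper. A minimal sketch of the intended behavior (the body mirrors the deleted xlnet-local copy shown below; save_steps takes over the alignment role, though the shared helper's second parameter is still named steps_per_epoch):

    def steps_to_run(current_step, save_steps, steps_per_loop):
      """Steps for the next device loop, never crossing a save boundary."""
      if steps_per_loop <= 0:
        raise ValueError('steps_per_loop should be positive integer.')
      if steps_per_loop == 1:
        return steps_per_loop
      remainder = current_step % save_steps
      if remainder != 0:
        # Cut the loop short so current_step lands exactly on a save boundary.
        return min(save_steps - remainder, steps_per_loop)
      return steps_per_loop

    # With save_steps=100 and steps_per_loop=32:
    #   steps_to_run(0, 100, 32)  -> 32
    #   steps_to_run(96, 100, 32) -> 4   (stop at step 100: save + eval)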
@@ -59,7 +59,7 @@ def _float_metric_value(metric):
   return metric.result().numpy().astype(float)
 
-def _steps_to_run(current_step, steps_per_epoch, steps_per_loop):
+def steps_to_run(current_step, steps_per_epoch, steps_per_loop):
   """Calculates steps to run on device."""
   if steps_per_loop <= 0:
     raise ValueError('steps_per_loop should be positive integer.')
@@ -353,7 +353,7 @@ def run_customized_training_loop(
     _run_callbacks_on_batch_begin(current_step)
     # Runs several steps in the host while loop.
-    steps = _steps_to_run(current_step, steps_per_epoch, steps_per_loop)
+    steps = steps_to_run(current_step, steps_per_epoch, steps_per_loop)
     if steps == 1:
       # TODO(zongweiz): merge with train_steps once tf.while_loop
...
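
Renaming _steps_to_run to steps_to_run makes the helper public, so the xlnet trainer can import it instead of keeping its own copy. A hedged usage sketch (only the module name model_training_utils appears in this diff; the exact import path is an assumption):

    from official.modeling import model_training_utils  # path is an assumption

    steps = model_training_utils.steps_to_run(current_step, save_steps,
                                              steps_per_loop)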
@@ -127,8 +127,6 @@ flags.DEFINE_float(
 flags.DEFINE_float(
     "init_range", default=0.1, help="Initialization std when init is uniform.")
-flags.DEFINE_integer(
-    "train_data_size", default=130738, help="Number of training data samples.")
 flags.DEFINE_integer(
     "test_data_size", default=12048, help="Number of test data samples.")
 flags.DEFINE_string(
...
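
With train_data_size gone, the total step count comes solely from train_steps, and save/eval cadence from save_steps. Invocations still passing the removed flag will fail flag parsing (absl raises UnrecognizedFlagError for unknown flags). A hypothetical invocation after this change (the script name and any flags not shown in this diff are assumptions):

    python run_classifier.py \
      --train_steps=2000 \
      --save_steps=500 \
      --iterations=100
    # note: no --train_data_size anymore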
@@ -147,7 +147,6 @@ def main(unused_argv):
       strategy, False, FLAGS.test_tfrecord_path)
 
   total_training_steps = FLAGS.train_steps
-  steps_per_epoch = int(FLAGS.train_data_size / FLAGS.train_batch_size)
   steps_per_loop = FLAGS.iterations
   eval_steps = int(FLAGS.test_data_size / FLAGS.test_batch_size)
   eval_fn = functools.partial(run_evaluation, strategy, test_input_fn,
@@ -181,7 +180,6 @@ def main(unused_argv):
       init_checkpoint=FLAGS.init_checkpoint,
       init_from_transformerxl=FLAGS.init_from_transformerxl,
       total_training_steps=total_training_steps,
-      steps_per_epoch=steps_per_epoch,
       steps_per_loop=steps_per_loop,
       optimizer=optimizer,
       learning_rate_fn=learning_rate_fn,
...
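
All three entry points (this hunk and the two that follow) now build the trainer without steps_per_epoch. A condensed sketch of the updated call; the parameter names come from the train() signature below, while the training_utils alias and the FLAGS names for model_dir and save_steps are assumptions:

    training_utils.train(
        strategy=strategy,
        model_fn=model_fn,
        input_meta_data=input_meta_data,
        train_input_fn=train_input_fn,
        total_training_steps=FLAGS.train_steps,
        steps_per_loop=FLAGS.iterations,
        optimizer=optimizer,
        learning_rate_fn=learning_rate_fn,
        model_dir=FLAGS.model_dir,    # flag name is an assumption
        save_steps=FLAGS.save_steps)  # flag name is an assumption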
@@ -105,7 +105,7 @@ def main(unused_argv):
                                      num_hosts)
 
   total_training_steps = FLAGS.train_steps
-  steps_per_epoch = int(FLAGS.train_data_size / FLAGS.train_batch_size)
   steps_per_loop = FLAGS.iterations
 
   optimizer, learning_rate_fn = optimization.create_optimizer(
@@ -139,7 +139,6 @@ def main(unused_argv):
       init_checkpoint=FLAGS.init_checkpoint,
       init_from_transformerxl=FLAGS.init_from_transformerxl,
       total_training_steps=total_training_steps,
-      steps_per_epoch=steps_per_epoch,
       steps_per_loop=steps_per_loop,
       optimizer=optimizer,
       learning_rate_fn=learning_rate_fn,
...
@@ -239,7 +239,6 @@ def main(unused_argv):
                                      FLAGS.test_tfrecord_path)
 
   total_training_steps = FLAGS.train_steps
-  steps_per_epoch = int(FLAGS.train_data_size / FLAGS.train_batch_size)
   steps_per_loop = FLAGS.iterations
   eval_steps = int(FLAGS.test_data_size / FLAGS.test_batch_size)
@@ -286,7 +285,6 @@ def main(unused_argv):
       init_checkpoint=FLAGS.init_checkpoint,
       init_from_transformerxl=FLAGS.init_from_transformerxl,
       total_training_steps=total_training_steps,
-      steps_per_epoch=steps_per_epoch,
       steps_per_loop=steps_per_loop,
       optimizer=optimizer,
       learning_rate_fn=learning_rate_fn,
...
@@ -49,26 +49,12 @@ def _float_metric_value(metric):
   return metric.result().numpy().astype(float)
 
-def _steps_to_run(current_step, steps_per_epoch, steps_per_loop):
-  """Calculates steps to run on device."""
-  if steps_per_loop <= 0:
-    raise ValueError("steps_per_loop should be positive integer.")
-  if steps_per_loop == 1:
-    return steps_per_loop
-  remainder_in_epoch = current_step % steps_per_epoch
-  if remainder_in_epoch != 0:
-    return min(steps_per_epoch - remainder_in_epoch, steps_per_loop)
-  else:
-    return steps_per_loop
-
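The xlnet-local helper above is deleted in favor of the shared implementation, which it matched apart from quote style; the call site further down switches to model_training_utils.steps_to_run and passes save_steps where steps_per_epoch used to go, so behavior changes only in what the loop aligns to, not in how alignment is computed.
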
 def train(
     strategy: tf.distribute.Strategy,
     model_fn: Callable,
     input_meta_data: Dict,
     train_input_fn: Callable,
     total_training_steps: int,
-    steps_per_epoch: int,
     steps_per_loop: int,
     optimizer: tf.keras.optimizers.Optimizer,
     learning_rate_fn: tf.keras.optimizers.schedules.LearningRateSchedule,
@@ -90,9 +76,6 @@ def train(
       `n_layer`, `batch_size_per_core` and `d_model`.
     train_input_fn: Function returns a tf.data.Dataset used for training.
     total_training_steps: Number of steps to train in total.
-    steps_per_epoch: Number of steps to run per epoch. At the end of each
-      epoch, model checkpoint will be saved and evaluation will be conducted
-      if evaluation dataset is provided.
     steps_per_loop: Number of steps per graph-mode loop. In order to reduce
       communication in eager context, training logs are printed every
       steps_per_loop.
@@ -111,7 +94,8 @@ def train(
       `model_fn`.
     model_dir: The directory of model (checkpoints, summaries).
     save_steps: The frequency to save checkpoints. Every save_steps, we save a
-      model checkpoint.
+      model checkpoint. Model checkpoint will be saved and evaluation will be
+      conducted if evaluation dataset is provided.
     run_eagerly: Whether to run training eagerly.
 
   Returns:
@@ -120,12 +104,12 @@ def train(
     TypeError: if model directory is not specified.
   """
   required_arguments = [
-      train_input_fn, total_training_steps, steps_per_epoch, steps_per_loop,
-      optimizer, learning_rate_fn
+      train_input_fn, total_training_steps, steps_per_loop, optimizer,
+      learning_rate_fn, save_steps
   ]
   if [arg for arg in required_arguments if arg is None]:
-    raise ValueError("`train_input_fn`, `total_training_steps`, "
-                     "`steps_per_epoch`, `steps_per_loop`, `optimizer` and "
-                     "`learning_rate_fn` are required parameters.")
+    raise ValueError("`train_input_fn`, `total_training_steps`, "
+                     "`steps_per_loop`, `optimizer`, `save_steps` and "
+                     "`learning_rate_fn` are required parameters.")
   if not model_dir:
     raise TypeError("Model directory must be specified.")
@@ -159,7 +143,7 @@ def train(
           transformer_xl=model.transformerxl_model)
     else:
       checkpoint = tf.train.Checkpoint(model=model)
-    checkpoint.restore(init_checkpoint)
+    checkpoint.restore(init_checkpoint).assert_existing_objects_matched()
 
   model.optimizer = optimizer
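
The restore now chains assert_existing_objects_matched(), which turns a silently ignored bad init_checkpoint into a hard failure: it raises if any Python objects already attached to the checkpoint (the model, or its transformerxl_model) found no matching values, while still tolerating extra, unused entries in the checkpoint file. A minimal sketch of the check in isolation:

    checkpoint = tf.train.Checkpoint(model=model)
    status = checkpoint.restore(init_checkpoint)
    # Raises AssertionError if variables built so far were not restored;
    # values in the checkpoint with no counterpart are still permitted.
    status.assert_existing_objects_matched()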
@@ -274,7 +258,8 @@ def train(
     if train_metric:
       train_metric.reset_states()
 
-    steps = _steps_to_run(current_step, steps_per_epoch, steps_per_loop)
+    steps = model_training_utils.steps_to_run(current_step, save_steps,
+                                              steps_per_loop)
     train_steps(train_iterator, tf.convert_to_tensor(steps, dtype=tf.int32))
     current_step += steps
     train_loss = _float_metric_value(train_loss_metric)
@@ -299,13 +284,11 @@ def train(
             _float_metric_value(train_metric),
             step=current_step)
       train_summary_writer.flush()
-    if model_dir:
-      if (save_steps is None) or (save_steps and
-                                  current_step % save_steps == 0):
-        _save_checkpoint(checkpoint, model_dir,
-                         checkpoint_name.format(step=current_step))
+    if model_dir and current_step % save_steps == 0:
+      _save_checkpoint(checkpoint, model_dir,
+                       checkpoint_name.format(step=current_step))
 
-    if test_input_fn and current_step % steps_per_epoch == 0:
+    if test_input_fn and current_step % save_steps == 0:
       logging.info("Running evaluation after step: %s.", current_step)
...
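
One behavioral consequence worth noting: save_steps joins the required arguments, so passing save_steps=None, which previously meant "save a checkpoint after every host loop", now raises the ValueError above. Checkpointing and evaluation now fire together whenever current_step % save_steps == 0, and because steps_to_run never lets the inner loop cross a save boundary, current_step always lands exactly on those multiples.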