Unverified Commit a1c47f28 authored by saberkun's avatar saberkun Committed by GitHub
Browse files

Merged commit includes the following changes: (#7049)

253850824  by hongkuny<hongkuny@google.com>:

    Improve bert training utils.

--
253818191  by hongkuny<hongkuny@google.com>:

    Update savedmodel export to use new model.save() api.

--

PiperOrigin-RevId: 253850824
parent 90f8c43b
...@@ -49,11 +49,10 @@ class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase): ...@@ -49,11 +49,10 @@ class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
"""Base class to hold methods common to test classes in the module.""" """Base class to hold methods common to test classes in the module."""
def __init__(self, output_dir=None): def __init__(self, output_dir=None):
super(BertClassifyBenchmarkBase, self).__init__(output_dir)
self.num_epochs = None self.num_epochs = None
self.num_steps_per_epoch = None self.num_steps_per_epoch = None
super(BertClassifyBenchmarkBase, self).__init__(output_dir)
@flagsaver.flagsaver @flagsaver.flagsaver
def _run_bert_classifier(self, callbacks=None): def _run_bert_classifier(self, callbacks=None):
"""Starts BERT classification task.""" """Starts BERT classification task."""
...@@ -72,6 +71,7 @@ class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase): ...@@ -72,6 +71,7 @@ class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
math.ceil(input_meta_data['eval_data_size'] / FLAGS.eval_batch_size)) math.ceil(input_meta_data['eval_data_size'] / FLAGS.eval_batch_size))
strategy = distribution_utils.get_distribution_strategy( strategy = distribution_utils.get_distribution_strategy(
distribution_strategy='mirrored', num_gpus=self.num_gpus) distribution_strategy='mirrored', num_gpus=self.num_gpus)
steps_per_loop = 1
run_classifier.run_customized_training( run_classifier.run_customized_training(
strategy, strategy,
...@@ -80,6 +80,7 @@ class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase): ...@@ -80,6 +80,7 @@ class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
FLAGS.model_dir, FLAGS.model_dir,
epochs, epochs,
steps_per_epoch, steps_per_epoch,
steps_per_loop,
eval_steps, eval_steps,
warmup_steps, warmup_steps,
FLAGS.learning_rate, FLAGS.learning_rate,
......
...@@ -28,7 +28,7 @@ def export_bert_model(model_export_path, ...@@ -28,7 +28,7 @@ def export_bert_model(model_export_path,
model=None, model=None,
model_fn=None, model_fn=None,
checkpoint_dir=None): checkpoint_dir=None):
"""Export BERT model for serving. """Export BERT model for serving which does not include the optimizer.
Arguments: Arguments:
model_export_path: Path to which exported model will be saved. model_export_path: Path to which exported model will be saved.
...@@ -39,7 +39,7 @@ def export_bert_model(model_export_path, ...@@ -39,7 +39,7 @@ def export_bert_model(model_export_path,
checkpoint_dir: Path from which model weights will be loaded. checkpoint_dir: Path from which model weights will be loaded.
""" """
if model: if model:
tf.keras.experimental.export_saved_model(model, model_export_path) model.save(model_export_path, include_optimizer=False, save_format='tf')
return return
assert model_fn and checkpoint_dir assert model_fn and checkpoint_dir
...@@ -50,7 +50,8 @@ def export_bert_model(model_export_path, ...@@ -50,7 +50,8 @@ def export_bert_model(model_export_path,
logging.info('Checkpoint file %s found and restoring from ' logging.info('Checkpoint file %s found and restoring from '
'checkpoint', latest_checkpoint_file) 'checkpoint', latest_checkpoint_file)
checkpoint.restore(latest_checkpoint_file).assert_existing_objects_matched() checkpoint.restore(latest_checkpoint_file).assert_existing_objects_matched()
tf.keras.experimental.export_saved_model(model_to_export, model_export_path) model_to_export.save(
model_export_path, include_optimizer=False, save_format='tf')
class BertModelCheckpoint(tf.keras.callbacks.Callback): class BertModelCheckpoint(tf.keras.callbacks.Callback):
......
...@@ -58,6 +58,22 @@ def _get_input_iterator(input_fn, strategy): ...@@ -58,6 +58,22 @@ def _get_input_iterator(input_fn, strategy):
return iterator return iterator
def _float_metric_value(metric):
  """Returns the current result of a scalar Keras metric as a float.

  Args:
    metric: A metric object (e.g. a `tf.keras.metrics.Mean`) whose `result()`
      returns a value that exposes `.numpy()`.

  Returns:
    The metric result pulled to the host and converted with `astype(float)`.
  """
  result_tensor = metric.result()
  host_value = result_tensor.numpy()
  return host_value.astype(float)
def _steps_to_run(current_step, steps_per_epoch, steps_per_loop):
  """Calculates how many training steps to run in the next inner loop.

  The returned count is capped so that the inner loop stops at the next epoch
  boundary, which lets the caller checkpoint and evaluate exactly at
  multiples of `steps_per_epoch`.

  Args:
    current_step: Number of training steps completed so far.
    steps_per_epoch: Number of training steps in one epoch.
    steps_per_loop: Maximum number of steps to run in one inner loop. Must be
      a positive integer.

  Returns:
    The number of steps to run next: at most `steps_per_loop`, and no more
    than the steps remaining in the current epoch.

  Raises:
    ValueError: If `steps_per_loop` is zero or negative.
  """
  if steps_per_loop <= 0:
    # A non-positive count would make the caller iterate over an empty range,
    # so its step counter would never advance and training would spin forever.
    raise ValueError('steps_per_loop should be positive integer.')
  if steps_per_loop == 1:
    return steps_per_loop
  # On an epoch boundary the remainder is 0, so a full epoch is still left.
  steps_left_in_epoch = steps_per_epoch - current_step % steps_per_epoch
  return min(steps_left_in_epoch, steps_per_loop)
def run_customized_training_loop( def run_customized_training_loop(
# pylint: disable=invalid-name # pylint: disable=invalid-name
_sentinel=None, _sentinel=None,
...@@ -68,6 +84,7 @@ def run_customized_training_loop( ...@@ -68,6 +84,7 @@ def run_customized_training_loop(
model_dir=None, model_dir=None,
train_input_fn=None, train_input_fn=None,
steps_per_epoch=None, steps_per_epoch=None,
steps_per_loop=1,
epochs=1, epochs=1,
eval_input_fn=None, eval_input_fn=None,
eval_steps=None, eval_steps=None,
...@@ -90,7 +107,12 @@ def run_customized_training_loop( ...@@ -90,7 +107,12 @@ def run_customized_training_loop(
model_dir: Model directory used during training for restoring/saving model model_dir: Model directory used during training for restoring/saving model
weights. weights.
train_input_fn: Function that returns a tf.data.Dataset used for training. train_input_fn: Function that returns a tf.data.Dataset used for training.
steps_per_epoch: Number of steps to run per epoch. steps_per_epoch: Number of steps to run per epoch. At the end of each
epoch, model checkpoint will be saved and evaluation will be conducted
if evaluation dataset is provided.
steps_per_loop: Number of steps per graph-mode loop. In order to reduce
communication in eager context, training logs are printed every
steps_per_loop.
epochs: Number of epochs to train. epochs: Number of epochs to train.
eval_input_fn: Function that returns evaluation dataset. If none, eval_input_fn: Function that returns evaluation dataset. If none,
evaluation is skipped. evaluation is skipped.
...@@ -125,8 +147,14 @@ def run_customized_training_loop( ...@@ -125,8 +147,14 @@ def run_customized_training_loop(
] ]
if [arg for arg in required_arguments if arg is None]: if [arg for arg in required_arguments if arg is None]:
raise ValueError('`strategy`, `model_fn`, `loss_fn`, `model_dir`, ' raise ValueError('`strategy`, `model_fn`, `loss_fn`, `model_dir`, '
'and `steps_per_epoch` are required parameters') '`steps_per_loop` and `steps_per_epoch` are required '
'parameters.')
if steps_per_loop > steps_per_epoch:
logging.error(
'steps_per_loop: %d is specified to be greater than '
' steps_per_epoch: %d, we will use steps_per_epoch as'
' steps_per_loop.', steps_per_loop, steps_per_epoch)
steps_per_loop = steps_per_epoch
assert tf.executing_eagerly() assert tf.executing_eagerly()
if eval_input_fn and (eval_steps is None or metric_fn is None): if eval_input_fn and (eval_steps is None or metric_fn is None):
...@@ -161,12 +189,14 @@ def run_customized_training_loop( ...@@ -161,12 +189,14 @@ def run_customized_training_loop(
checkpoint.restore(init_checkpoint).assert_consumed() checkpoint.restore(init_checkpoint).assert_consumed()
logging.info('Loading from checkpoint file completed') logging.info('Loading from checkpoint file completed')
metric = metric_fn() if metric_fn else None train_loss_metric = tf.keras.metrics.Mean(
'training_loss', dtype=tf.float32)
eval_metric = metric_fn() if metric_fn else None
# If evaluation is required, make a copy of metric as it will be used by # If evaluation is required, make a copy of metric as it will be used by
# both train and evaluation. # both train and evaluation.
train_metric = ( train_metric = (
metric.__class__.from_config(metric.get_config()) eval_metric.__class__.from_config(eval_metric.get_config())
if metric else None) if eval_metric else None)
@tf.function @tf.function
def train_step(iterator): def train_step(iterator):
...@@ -179,20 +209,16 @@ def run_customized_training_loop( ...@@ -179,20 +209,16 @@ def run_customized_training_loop(
with tf.GradientTape() as tape: with tf.GradientTape() as tape:
model_outputs = model(inputs) model_outputs = model(inputs)
loss = loss_fn(labels, model_outputs) loss = loss_fn(labels, model_outputs)
if train_metric:
train_metric.update_state(labels, model_outputs)
tvars = model.trainable_variables tvars = model.trainable_variables
grads = tape.gradient(loss, tvars) grads = tape.gradient(loss, tvars)
optimizer.apply_gradients(zip(grads, tvars)) optimizer.apply_gradients(zip(grads, tvars))
return loss # For reporting, the metric takes the mean of losses.
train_loss_metric.update_state(loss)
if train_metric:
train_metric.update_state(labels, model_outputs)
per_replica_losses = strategy.experimental_run_v2( strategy.experimental_run_v2(_replicated_step, args=(next(iterator),))
_replicated_step, args=(next(iterator),))
# For reporting, we returns the mean of losses.
loss = strategy.reduce(
tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
return loss
@tf.function @tf.function
def test_step(iterator): def test_step(iterator):
...@@ -203,7 +229,7 @@ def run_customized_training_loop( ...@@ -203,7 +229,7 @@ def run_customized_training_loop(
inputs, labels = inputs inputs, labels = inputs
model_outputs = model(inputs, training=False) model_outputs = model(inputs, training=False)
metric.update_state(labels, model_outputs) eval_metric.update_state(labels, model_outputs)
strategy.experimental_run_v2(_test_step_fn, args=(next(iterator),)) strategy.experimental_run_v2(_test_step_fn, args=(next(iterator),))
...@@ -211,11 +237,8 @@ def run_customized_training_loop( ...@@ -211,11 +237,8 @@ def run_customized_training_loop(
"""Runs validation steps and aggregate metrics.""" """Runs validation steps and aggregate metrics."""
for _ in range(eval_steps): for _ in range(eval_steps):
test_step(test_iterator) test_step(test_iterator)
metric_result = metric.result().numpy().astype(float)
logging.info('Step: [%d] Validation metric = %f', current_training_step, logging.info('Step: [%d] Validation metric = %f', current_training_step,
metric_result) _float_metric_value(eval_metric))
return metric_result
def _run_callbacks_on_batch_begin(batch): def _run_callbacks_on_batch_begin(batch):
"""Runs custom callbacks at the start of every step.""" """Runs custom callbacks at the start of every step."""
...@@ -244,25 +267,29 @@ def run_customized_training_loop( ...@@ -244,25 +267,29 @@ def run_customized_training_loop(
current_step = optimizer.iterations.numpy() current_step = optimizer.iterations.numpy()
checkpoint_name = 'ctl_step_{step}.ckpt' checkpoint_name = 'ctl_step_{step}.ckpt'
train_metric_result = None
eval_metric_result = None
train_loss = None
while current_step < total_training_steps: while current_step < total_training_steps:
current_step += 1 # Training loss/metric are taking average over steps inside micro
_run_callbacks_on_batch_begin(current_step) # training loop. We reset the their values before each round.
train_loss = train_step(train_iterator).numpy().astype(float) train_loss_metric.reset_states()
if train_metric: if train_metric:
train_metric_result = train_metric.result().numpy().astype(float) train_metric.reset_states()
logging.info('Train Step: %d/%d / loss = %s / training metric = %s', state_step = current_step
current_step, total_training_steps, train_loss, _run_callbacks_on_batch_begin(state_step)
train_metric_result) for _ in range(
else: _steps_to_run(state_step, steps_per_epoch, steps_per_loop)):
logging.info('Train Step: %d/%d / loss = %s', current_step, current_step += 1
total_training_steps, train_loss) train_step(train_iterator)
_run_callbacks_on_batch_end(state_step)
_run_callbacks_on_batch_end(current_step) # Updates training logging.
training_status = 'Train Step: %d/%d / loss = %s' % (
current_step, total_training_steps,
_float_metric_value(train_loss_metric))
if train_metric:
training_status += ' training metric = %s' % _float_metric_value(
train_metric)
logging.info(training_status)
# Saves model checkpoints and run validation steps at every epoch end. # Saves model checkpoints and run validation steps at every epoch end.
if current_step % steps_per_epoch == 0: if current_step % steps_per_epoch == 0:
...@@ -276,31 +303,29 @@ def run_customized_training_loop( ...@@ -276,31 +303,29 @@ def run_customized_training_loop(
logging.info('Running evaluation after step: %s.', current_step) logging.info('Running evaluation after step: %s.', current_step)
_run_evaluation(current_step, _run_evaluation(current_step,
_get_input_iterator(eval_input_fn, strategy)) _get_input_iterator(eval_input_fn, strategy))
# Re-initialize evaluation metric.
# Re-initialize evaluation metric, except the last step. eval_metric.reset_states()
if metric and current_step < total_training_steps:
metric.reset_states()
train_metric.reset_states()
_save_checkpoint(checkpoint, model_dir, _save_checkpoint(checkpoint, model_dir,
checkpoint_name.format(step=current_step)) checkpoint_name.format(step=current_step))
if eval_input_fn: if eval_input_fn:
logging.info('Running final evaluation after training is complete.') logging.info('Running final evaluation after training is complete.')
eval_metric_result = _run_evaluation( _run_evaluation(current_step,
current_step, _get_input_iterator(eval_input_fn, strategy)) _get_input_iterator(eval_input_fn, strategy))
training_summary = { training_summary = {
'total_training_steps': total_training_steps, 'total_training_steps': total_training_steps,
'train_loss': train_loss 'train_loss': _float_metric_value(train_loss_metric),
} }
if train_metric_result: if eval_metric:
training_summary['train_metrics'] = train_metric_result training_summary['last_train_metrics'] = _float_metric_value(
if eval_metric_result: train_metric)
training_summary['eval_metrics'] = eval_metric_result training_summary['eval_metrics'] = _float_metric_value(eval_metric)
summary_path = os.path.join(model_dir, SUMMARY_TXT) summary_path = os.path.join(model_dir, SUMMARY_TXT)
with tf.io.gfile.GFile(summary_path, 'wb') as f: with tf.io.gfile.GFile(summary_path, 'wb') as f:
logging.info('Training Summary: \n%s', str(training_summary))
f.write(json.dumps(training_summary, indent=4)) f.write(json.dumps(training_summary, indent=4))
return model return model
...@@ -74,6 +74,11 @@ flags.DEFINE_integer('train_batch_size', 32, 'Batch size for training.') ...@@ -74,6 +74,11 @@ flags.DEFINE_integer('train_batch_size', 32, 'Batch size for training.')
flags.DEFINE_integer('eval_batch_size', 8, 'Batch size for evaluation.') flags.DEFINE_integer('eval_batch_size', 8, 'Batch size for evaluation.')
flags.DEFINE_integer('num_train_epochs', 3, flags.DEFINE_integer('num_train_epochs', 3,
'Total number of training epochs to perform.') 'Total number of training epochs to perform.')
# Number of training steps run back-to-back in one inner loop of the
# customized training loop; run_bert forwards this flag to
# run_customized_training as `steps_per_loop`.
flags.DEFINE_integer(
'steps_per_loop', 200,
'Number of steps per graph-mode loop. Only training step '
'happens inside the loop. Callbacks will not be called '
'inside.')
flags.DEFINE_float('learning_rate', 5e-5, 'The initial learning rate for Adam.') flags.DEFINE_float('learning_rate', 5e-5, 'The initial learning rate for Adam.')
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
...@@ -103,6 +108,7 @@ def run_customized_training(strategy, ...@@ -103,6 +108,7 @@ def run_customized_training(strategy,
model_dir, model_dir,
epochs, epochs,
steps_per_epoch, steps_per_epoch,
steps_per_loop,
eval_steps, eval_steps,
warmup_steps, warmup_steps,
initial_lr, initial_lr,
...@@ -148,6 +154,7 @@ def run_customized_training(strategy, ...@@ -148,6 +154,7 @@ def run_customized_training(strategy,
loss_fn=loss_fn, loss_fn=loss_fn,
model_dir=model_dir, model_dir=model_dir,
steps_per_epoch=steps_per_epoch, steps_per_epoch=steps_per_epoch,
steps_per_loop=steps_per_loop,
epochs=epochs, epochs=epochs,
train_input_fn=train_input_fn, train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn, eval_input_fn=eval_input_fn,
...@@ -204,19 +211,25 @@ def run_bert(strategy, input_meta_data): ...@@ -204,19 +211,25 @@ def run_bert(strategy, input_meta_data):
logging.info('Training using customized training loop TF 2.0 with distrubuted' logging.info('Training using customized training loop TF 2.0 with distrubuted'
'strategy.') 'strategy.')
use_remote_tpu = (FLAGS.strategy_type == 'tpu' and FLAGS.tpu) use_remote_tpu = (FLAGS.strategy_type == 'tpu' and FLAGS.tpu)
return run_customized_training( trained_model = run_customized_training(
strategy, strategy,
bert_config, bert_config,
input_meta_data, input_meta_data,
FLAGS.model_dir, FLAGS.model_dir,
epochs, epochs,
steps_per_epoch, steps_per_epoch,
FLAGS.steps_per_loop,
eval_steps, eval_steps,
warmup_steps, warmup_steps,
FLAGS.learning_rate, FLAGS.learning_rate,
FLAGS.init_checkpoint, FLAGS.init_checkpoint,
use_remote_tpu=use_remote_tpu) use_remote_tpu=use_remote_tpu)
if FLAGS.model_export_path:
model_saving_utils.export_bert_model(
FLAGS.model_export_path, model=trained_model)
return trained_model
def main(_): def main(_):
# Users should always run this script under TF 2.x # Users should always run this script under TF 2.x
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment