Unverified commit a1c47f28 authored by saberkun, committed by GitHub

Merged commit includes the following changes: (#7049)

253850824  by hongkuny<hongkuny@google.com>:

    Improve bert training utils.

--
253818191  by hongkuny<hongkuny@google.com>:

    Update savedmodel export to use new model.save() api.

--

PiperOrigin-RevId: 253850824
parent 90f8c43b
@@ -49,11 +49,10 @@ class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
"""Base class to hold methods common to test classes in the module."""
def __init__(self, output_dir=None):
super(BertClassifyBenchmarkBase, self).__init__(output_dir)
self.num_epochs = None
self.num_steps_per_epoch = None
super(BertClassifyBenchmarkBase, self).__init__(output_dir)
@flagsaver.flagsaver
def _run_bert_classifier(self, callbacks=None):
"""Starts BERT classification task."""
@@ -72,6 +71,7 @@ class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
math.ceil(input_meta_data['eval_data_size'] / FLAGS.eval_batch_size))
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy='mirrored', num_gpus=self.num_gpus)
steps_per_loop = 1
run_classifier.run_customized_training(
strategy,
@@ -80,6 +80,7 @@ class BertClassifyBenchmarkBase(benchmark_utils.BertBenchmarkBase):
FLAGS.model_dir,
epochs,
steps_per_epoch,
steps_per_loop,
eval_steps,
warmup_steps,
FLAGS.learning_rate,
......
@@ -28,7 +28,7 @@ def export_bert_model(model_export_path,
model=None,
model_fn=None,
checkpoint_dir=None):
"""Export BERT model for serving.
"""Export BERT model for serving which does not include the optimizer.
Arguments:
model_export_path: Path to which exported model will be saved.
@@ -39,7 +39,7 @@ def export_bert_model(model_export_path,
checkpoint_dir: Path from which model weights will be loaded.
"""
if model:
tf.keras.experimental.export_saved_model(model, model_export_path)
model.save(model_export_path, include_optimizer=False, save_format='tf')
return
assert model_fn and checkpoint_dir
@@ -50,7 +50,8 @@ def export_bert_model(model_export_path,
logging.info('Checkpoint file %s found and restoring from '
'checkpoint', latest_checkpoint_file)
checkpoint.restore(latest_checkpoint_file).assert_existing_objects_matched()
tf.keras.experimental.export_saved_model(model_to_export, model_export_path)
model_to_export.save(
model_export_path, include_optimizer=False, save_format='tf')
class BertModelCheckpoint(tf.keras.callbacks.Callback):
......
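Note on the export change above: tf.keras.experimental.export_saved_model was an experimental API, and model.save(..., save_format='tf') is the supported way to write a SavedModel. Below is a minimal, self-contained sketch of the call pattern; it is not code from this repo, and the toy model and path are placeholders.

import tensorflow as tf

# Placeholder model standing in for the BERT classifier.
model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(4,))])
model.compile(optimizer='adam', loss='mse')

export_path = '/tmp/example_saved_model'  # placeholder path
# Writes a TF SavedModel; optimizer state is intentionally dropped for serving.
model.save(export_path, include_optimizer=False, save_format='tf')

# The exported directory can be reloaded as a Keras model or by serving tools.
restored = tf.keras.models.load_model(export_path)

Dropping the optimizer keeps the exported artifact small, which is all that is needed for serving.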
@@ -58,6 +58,22 @@ def _get_input_iterator(input_fn, strategy):
return iterator
def _float_metric_value(metric):
"""Gets the value of a float-value keras metric."""
return metric.result().numpy().astype(float)
def _steps_to_run(current_step, steps_per_epoch, steps_per_loop):
"""Calculates steps to run on device."""
if steps_per_loop <= 1:
return steps_per_loop
remainder_in_epoch = current_step % steps_per_epoch
if remainder_in_epoch != 0:
return min(steps_per_epoch - remainder_in_epoch, steps_per_loop)
else:
return steps_per_loop
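A quick illustration of the two helpers above (values are hypothetical; tf is already imported in this module): _steps_to_run clips the micro training loop so it never runs past an epoch boundary, and _float_metric_value turns a Keras metric result into a plain Python float for logging.

# Hypothetical values: 100 steps per epoch, micro-loops of 32 steps.
assert _steps_to_run(0, steps_per_epoch=100, steps_per_loop=32) == 32
assert _steps_to_run(96, steps_per_epoch=100, steps_per_loop=32) == 4   # stop at epoch end
assert _steps_to_run(100, steps_per_epoch=100, steps_per_loop=32) == 32  # new epoch starts

# _float_metric_value returns a plain float suitable for logging or JSON.
loss_metric = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
loss_metric.update_state([1.0, 3.0])
assert _float_metric_value(loss_metric) == 2.0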
def run_customized_training_loop(
# pylint: disable=invalid-name
_sentinel=None,
@@ -68,6 +84,7 @@ def run_customized_training_loop(
model_dir=None,
train_input_fn=None,
steps_per_epoch=None,
steps_per_loop=1,
epochs=1,
eval_input_fn=None,
eval_steps=None,
@@ -90,7 +107,12 @@ def run_customized_training_loop(
model_dir: Model directory used during training for restoring/saving model
weights.
train_input_fn: Function that returns a tf.data.Dataset used for training.
steps_per_epoch: Number of steps to run per epoch.
steps_per_epoch: Number of steps to run per epoch. At the end of each
epoch, a model checkpoint is saved and evaluation is run if an
evaluation dataset is provided.
steps_per_loop: Number of steps per graph-mode loop. To reduce
communication overhead in eager context, training logs are printed only
every steps_per_loop steps.
epochs: Number of epochs to train.
eval_input_fn: Function that returns evaluation dataset. If none,
evaluation is skipped.
@@ -125,8 +147,14 @@ def run_customized_training_loop(
]
if [arg for arg in required_arguments if arg is None]:
raise ValueError('`strategy`, `model_fn`, `loss_fn`, `model_dir`, '
'and `steps_per_epoch` are required parameters')
'`steps_per_loop` and `steps_per_epoch` are required '
'parameters.')
if steps_per_loop > steps_per_epoch:
logging.error(
'steps_per_loop: %d is specified to be greater than '
'steps_per_epoch: %d; we will use steps_per_epoch as '
'steps_per_loop.', steps_per_loop, steps_per_epoch)
steps_per_loop = steps_per_epoch
assert tf.executing_eagerly()
if eval_input_fn and (eval_steps is None or metric_fn is None):
@@ -161,12 +189,14 @@ def run_customized_training_loop(
checkpoint.restore(init_checkpoint).assert_consumed()
logging.info('Loading from checkpoint file completed')
metric = metric_fn() if metric_fn else None
train_loss_metric = tf.keras.metrics.Mean(
'training_loss', dtype=tf.float32)
eval_metric = metric_fn() if metric_fn else None
# If evaluation is required, make a copy of the metric, as it will be used
# by both training and evaluation.
train_metric = (
metric.__class__.from_config(metric.get_config())
if metric else None)
eval_metric.__class__.from_config(eval_metric.get_config())
if eval_metric else None)
@tf.function
def train_step(iterator):
@@ -179,20 +209,16 @@ def run_customized_training_loop(
with tf.GradientTape() as tape:
model_outputs = model(inputs)
loss = loss_fn(labels, model_outputs)
if train_metric:
train_metric.update_state(labels, model_outputs)
tvars = model.trainable_variables
grads = tape.gradient(loss, tvars)
optimizer.apply_gradients(zip(grads, tvars))
return loss
# For reporting, the metric takes the mean of losses.
train_loss_metric.update_state(loss)
if train_metric:
train_metric.update_state(labels, model_outputs)
per_replica_losses = strategy.experimental_run_v2(
_replicated_step, args=(next(iterator),))
# For reporting, we return the mean of losses.
loss = strategy.reduce(
tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
return loss
strategy.experimental_run_v2(_replicated_step, args=(next(iterator),))
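The rewritten train_step above no longer reduces per-replica losses and returns a value to the host on every step; the loss is accumulated in train_loss_metric and read only when the outer loop logs. Below is a standalone sketch of that pattern, assuming a toy model and dataset (names are illustrative, not from this repo) and the same experimental_run_v2 API used in this file.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  # Toy stand-ins for the BERT model and optimizer, built eagerly so variables
  # are created in cross-replica context.
  model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
  optimizer = tf.keras.optimizers.SGD(0.1)
  # Aggregates per-replica, per-step losses; the host reads it lazily via
  # .result() when it logs, instead of reducing a loss tensor every step.
  train_loss_metric = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)

features = tf.random.uniform([64, 4])
labels = tf.random.uniform([64, 1])
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(8)
iterator = iter(strategy.experimental_distribute_dataset(dataset))

@tf.function
def train_step(iterator):
  def _replicated_step(inputs):
    x, y = inputs
    with tf.GradientTape() as tape:
      loss = tf.reduce_mean(tf.square(y - model(x)))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    train_loss_metric.update_state(loss)  # no per-step value returned to host
  # experimental_run_v2 is the TF 2.0 name; newer releases call this strategy.run.
  strategy.experimental_run_v2(_replicated_step, args=(next(iterator),))

train_step(iterator)
print('mean training loss so far:', float(train_loss_metric.result()))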
@tf.function
def test_step(iterator):
@@ -203,7 +229,7 @@ def run_customized_training_loop(
inputs, labels = inputs
model_outputs = model(inputs, training=False)
metric.update_state(labels, model_outputs)
eval_metric.update_state(labels, model_outputs)
strategy.experimental_run_v2(_test_step_fn, args=(next(iterator),))
@@ -211,11 +237,8 @@ def run_customized_training_loop(
"""Runs validation steps and aggregate metrics."""
for _ in range(eval_steps):
test_step(test_iterator)
metric_result = metric.result().numpy().astype(float)
logging.info('Step: [%d] Validation metric = %f', current_training_step,
metric_result)
return metric_result
_float_metric_value(eval_metric))
def _run_callbacks_on_batch_begin(batch):
"""Runs custom callbacks at the start of every step."""
@@ -244,25 +267,29 @@ def run_customized_training_loop(
current_step = optimizer.iterations.numpy()
checkpoint_name = 'ctl_step_{step}.ckpt'
train_metric_result = None
eval_metric_result = None
train_loss = None
while current_step < total_training_steps:
current_step += 1
_run_callbacks_on_batch_begin(current_step)
train_loss = train_step(train_iterator).numpy().astype(float)
# Training loss/metric take the average over steps inside the micro
# training loop. We reset their values before each round.
train_loss_metric.reset_states()
if train_metric:
train_metric_result = train_metric.result().numpy().astype(float)
train_metric.reset_states()
logging.info('Train Step: %d/%d / loss = %s / training metric = %s',
current_step, total_training_steps, train_loss,
train_metric_result)
else:
logging.info('Train Step: %d/%d / loss = %s', current_step,
total_training_steps, train_loss)
state_step = current_step
_run_callbacks_on_batch_begin(state_step)
for _ in range(
_steps_to_run(state_step, steps_per_epoch, steps_per_loop)):
current_step += 1
train_step(train_iterator)
_run_callbacks_on_batch_end(state_step)
_run_callbacks_on_batch_end(current_step)
# Updates training logging.
training_status = 'Train Step: %d/%d / loss = %s' % (
current_step, total_training_steps,
_float_metric_value(train_loss_metric))
if train_metric:
training_status += ' training metric = %s' % _float_metric_value(
train_metric)
logging.info(training_status)
# Saves model checkpoints and run validation steps at every epoch end.
if current_step % steps_per_epoch == 0:
@@ -276,31 +303,29 @@ def run_customized_training_loop(
logging.info('Running evaluation after step: %s.', current_step)
_run_evaluation(current_step,
_get_input_iterator(eval_input_fn, strategy))
# Re-initialize evaluation metric, except the last step.
if metric and current_step < total_training_steps:
metric.reset_states()
train_metric.reset_states()
# Re-initialize evaluation metric.
eval_metric.reset_states()
_save_checkpoint(checkpoint, model_dir,
checkpoint_name.format(step=current_step))
if eval_input_fn:
logging.info('Running final evaluation after training is complete.')
eval_metric_result = _run_evaluation(
current_step, _get_input_iterator(eval_input_fn, strategy))
_run_evaluation(current_step,
_get_input_iterator(eval_input_fn, strategy))
training_summary = {
'total_training_steps': total_training_steps,
'train_loss': train_loss
'train_loss': _float_metric_value(train_loss_metric),
}
if train_metric_result:
training_summary['train_metrics'] = train_metric_result
if eval_metric_result:
training_summary['eval_metrics'] = eval_metric_result
if eval_metric:
training_summary['last_train_metrics'] = _float_metric_value(
train_metric)
training_summary['eval_metrics'] = _float_metric_value(eval_metric)
summary_path = os.path.join(model_dir, SUMMARY_TXT)
with tf.io.gfile.GFile(summary_path, 'wb') as f:
logging.info('Training Summary: \n%s', str(training_summary))
f.write(json.dumps(training_summary, indent=4))
return model
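For reference, the summary written at the end of the loop can be read back as JSON. A small sketch follows, assuming SUMMARY_TXT resolves to 'training_summary.txt' (the constant is defined elsewhere in this module) and a placeholder model_dir.

import json
import os
import tensorflow as tf

model_dir = '/tmp/bert_classifier'  # placeholder; use your actual model_dir
summary_path = os.path.join(model_dir, 'training_summary.txt')  # assumed file name
with tf.io.gfile.GFile(summary_path, 'rb') as reader:
  summary = json.loads(reader.read())
# Keys written by the loop: total_training_steps, train_loss, and, when metrics
# are configured, last_train_metrics and eval_metrics.
print(summary['total_training_steps'], summary['train_loss'])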
@@ -74,6 +74,11 @@ flags.DEFINE_integer('train_batch_size', 32, 'Batch size for training.')
flags.DEFINE_integer('eval_batch_size', 8, 'Batch size for evaluation.')
flags.DEFINE_integer('num_train_epochs', 3,
'Total number of training epochs to perform.')
flags.DEFINE_integer(
'steps_per_loop', 200,
'Number of steps per graph-mode loop. Only the training '
'step happens inside the loop. Callbacks will not be '
'called inside.')
flags.DEFINE_float('learning_rate', 5e-5, 'The initial learning rate for Adam.')
FLAGS = flags.FLAGS
@@ -103,6 +108,7 @@ def run_customized_training(strategy,
model_dir,
epochs,
steps_per_epoch,
steps_per_loop,
eval_steps,
warmup_steps,
initial_lr,
@@ -148,6 +154,7 @@ def run_customized_training(strategy,
loss_fn=loss_fn,
model_dir=model_dir,
steps_per_epoch=steps_per_epoch,
steps_per_loop=steps_per_loop,
epochs=epochs,
train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn,
@@ -204,19 +211,25 @@ def run_bert(strategy, input_meta_data):
logging.info('Training using customized training loop TF 2.0 with distributed '
'strategy.')
use_remote_tpu = (FLAGS.strategy_type == 'tpu' and FLAGS.tpu)
return run_customized_training(
trained_model = run_customized_training(
strategy,
bert_config,
input_meta_data,
FLAGS.model_dir,
epochs,
steps_per_epoch,
FLAGS.steps_per_loop,
eval_steps,
warmup_steps,
FLAGS.learning_rate,
FLAGS.init_checkpoint,
use_remote_tpu=use_remote_tpu)
if FLAGS.model_export_path:
model_saving_utils.export_bert_model(
FLAGS.model_export_path, model=trained_model)
return trained_model
def main(_):
# Users should always run this script under TF 2.x
......