Commit bf6d6f6f authored by Hongkun Yu, committed by A. Unique TensorFlower

Adds scale_loss to run_customized_training_loop, which should be the correct treatment.

PiperOrigin-RevId: 301441181
parent 4f58b1f7
@@ -102,6 +102,7 @@ def run_customized_training_loop(
     strategy=None,
     model_fn=None,
     loss_fn=None,
+    scale_loss=True,
     model_dir=None,
     train_input_fn=None,
     steps_per_epoch=None,
@@ -129,6 +130,8 @@
       to be used for initial checkpoint -- if provided.
     loss_fn: Function with signature func(labels, logits) and returns a loss
       tensor.
+    scale_loss: Whether to divide the raw loss by number of replicas before
+      gradients calculation.
     model_dir: Model directory used during training for restoring/saving model
       weights.
     train_input_fn: Function that returns a tf.data.Dataset used for training.
@@ -284,6 +287,12 @@
       with tf.GradientTape() as tape:
         model_outputs = model(inputs, training=True)
         loss = loss_fn(labels, model_outputs)
+        # Raw loss is used for reporting in metrics/logs.
+        raw_loss = loss
+        if scale_loss:
+          # Scales down the loss for gradients to be invariant from replicas.
+          loss = loss / strategy.num_replicas_in_sync
       if explicit_allreduce:
         grad_utils.minimize_using_explicit_allreduce(tape, optimizer, loss,
                                                      training_vars,
@@ -300,7 +309,7 @@
         grads = tape.gradient(loss, training_vars)
         optimizer.apply_gradients(zip(grads, training_vars))
       # For reporting, the metric takes the mean of losses.
-      train_loss_metric.update_state(loss)
+      train_loss_metric.update_state(raw_loss)
       for metric in train_metrics:
         metric.update_state(labels, model_outputs)
......
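The scaling added above works because, under a tf.distribute strategy, the gradients computed on each replica are summed when they are applied, so dividing each replica's mean loss by num_replicas_in_sync makes the resulting update equal to the gradient of the global-batch mean loss. The following standalone sketch (illustrative only, not code from this commit; it assumes a TF 2.2+ runtime where strategy.run is available and uses a toy Dense model) mirrors that treatment in a minimal custom loop:

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # N replicas (1 on a CPU-only machine)

with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
  model.build((None, 4))
  optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
  # Per-example losses; reduction is done manually below, as in the loop above.
  loss_obj = tf.keras.losses.MeanSquaredError(
      reduction=tf.keras.losses.Reduction.NONE)

@tf.function
def train_step(dist_inputs):
  def step_fn(inputs):
    features, labels = inputs
    with tf.GradientTape() as tape:
      per_example_loss = loss_obj(labels, model(features, training=True))
      raw_loss = tf.reduce_mean(per_example_loss)      # unscaled, for logging
      loss = raw_loss / strategy.num_replicas_in_sync  # scaled, for gradients
    grads = tape.gradient(loss, model.trainable_variables)
    # apply_gradients() sums the per-replica gradients, so the 1/N factor
    # above makes the update equal to the gradient of the global mean loss.
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return raw_loss

  return strategy.run(step_fn, args=(dist_inputs,))

dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal([8, 4]), tf.random.normal([8, 1]))).batch(4)
for batch in strategy.experimental_distribute_dataset(dataset):
  train_step(batch)
```

On a CPU-only machine MirroredStrategy creates a single replica, so the division is a no-op there; the reported raw_loss stays comparable across replica counts either way.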
@@ -61,7 +61,7 @@ common_flags.define_common_bert_flags()
 FLAGS = flags.FLAGS
-def get_loss_fn(num_classes, loss_factor=1.0):
+def get_loss_fn(num_classes):
   """Gets the classification loss function."""
   def classification_loss_fn(labels, logits):
@@ -72,9 +72,7 @@ def get_loss_fn(num_classes, loss_factor=1.0):
         tf.cast(labels, dtype=tf.int32), depth=num_classes, dtype=tf.float32)
     per_example_loss = -tf.reduce_sum(
         tf.cast(one_hot_labels, dtype=tf.float32) * log_probs, axis=-1)
-    loss = tf.reduce_mean(per_example_loss)
-    loss *= loss_factor
-    return loss
+    return tf.reduce_mean(per_example_loss)
   return classification_loss_fn
@@ -135,17 +133,7 @@ def run_bert_classifier(strategy,
         use_graph_rewrite=common_flags.use_graph_rewrite())
     return classifier_model, core_model
-  # During distributed training, loss used for gradient computation is
-  # summed over from all replicas. When Keras compile/fit() API is used,
-  # the fit() API internally normalizes the loss by dividing the loss by
-  # the number of replicas used for computation. However, when custom
-  # training loop is used this is not done automatically and should be
-  # done manually by the end user.
-  loss_multiplier = 1.0
-  if FLAGS.scale_loss and not use_keras_compile_fit:
-    loss_multiplier = 1.0 / strategy.num_replicas_in_sync
-  loss_fn = get_loss_fn(num_classes, loss_factor=loss_multiplier)
+  loss_fn = get_loss_fn(num_classes)
   # Defines evaluation metrics function, which will create metrics in the
   # correct device and strategy scope.
......
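The comment removed above explained that Keras compile/fit() divides the summed loss by the number of replicas automatically, while a custom training loop has to apply the 1/N factor itself; with this change that factor lives in run_customized_training_loop instead of the loss function. A tiny numeric check (illustrative, assumes equal per-replica batch sizes) of why the 1/N factor recovers the global-batch mean:

```python
import numpy as np

per_example_losses = np.array([0.2, 0.4, 0.6, 0.8])  # a global batch of 4 examples
num_replicas = 2
replica_batches = np.split(per_example_losses, num_replicas)

# Each replica takes the mean of its local batch and scales it by 1/N;
# summing across replicas (what gradient aggregation effectively does)
# recovers the global-batch mean that Keras fit() would have used.
scaled_sum = sum(batch.mean() / num_replicas for batch in replica_batches)
assert np.isclose(scaled_sum, per_example_losses.mean())  # both equal 0.5
```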
@@ -74,11 +74,11 @@ def get_pretrain_dataset_fn(input_file_pattern, seq_length,
   return _dataset_fn
-def get_loss_fn(loss_factor=1.0):
+def get_loss_fn():
   """Returns loss function for BERT pretraining."""
   def _bert_pretrain_loss_fn(unused_labels, losses, **unused_args):
-    return tf.reduce_mean(losses) * loss_factor
+    return tf.reduce_mean(losses)
   return _bert_pretrain_loss_fn
@@ -116,9 +116,8 @@ def run_customized_training(strategy,
   trained_model = model_training_utils.run_customized_training_loop(
       strategy=strategy,
       model_fn=_get_pretrain_model,
-      loss_fn=get_loss_fn(
-          loss_factor=1.0 /
-          strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0),
+      loss_fn=get_loss_fn(),
+      scale_loss=FLAGS.scale_loss,
       model_dir=model_dir,
       train_input_fn=train_input_fn,
       steps_per_epoch=steps_per_epoch,
......
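The pretraining call above now passes an unscaled loss_fn plus scale_loss=FLAGS.scale_loss instead of baking loss_factor=1/num_replicas_in_sync into get_loss_fn. The two formulations give identical gradients; a small self-contained check (illustrative, uses a toy variable in place of the BERT model):

```python
import tensorflow as tf

num_replicas = 8  # stand-in for strategy.num_replicas_in_sync
w = tf.Variable(2.0)
x = tf.constant(3.0)

with tf.GradientTape(persistent=True) as tape:
  loss = tf.square(w * x)                  # unscaled value returned by loss_fn
  old_style = loss * (1.0 / num_replicas)  # old: loss_factor baked into loss_fn
  new_style = loss / num_replicas          # new: scale_loss applied in the loop

tf.debugging.assert_near(tape.gradient(old_style, w),
                         tape.gradient(new_style, w))
```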
@@ -90,8 +90,7 @@ FLAGS = flags.FLAGS
 def squad_loss_fn(start_positions,
                   end_positions,
                   start_logits,
-                  end_logits,
-                  loss_factor=1.0):
+                  end_logits):
   """Returns sparse categorical crossentropy for start/end logits."""
   start_loss = tf.keras.losses.sparse_categorical_crossentropy(
       start_positions, start_logits, from_logits=True)
@@ -99,11 +98,10 @@ def squad_loss_fn(start_positions,
       end_positions, end_logits, from_logits=True)
   total_loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2
-  total_loss *= loss_factor
   return total_loss
-def get_loss_fn(loss_factor=1.0):
+def get_loss_fn():
   """Gets a loss function for squad task."""
   def _loss_fn(labels, model_outputs):
@@ -114,8 +112,7 @@ def get_loss_fn(loss_factor=1.0):
         start_positions,
         end_positions,
         start_logits,
-        end_logits,
-        loss_factor=loss_factor)
+        end_logits)
   return _loss_fn
@@ -249,14 +246,6 @@ def train_squad(strategy,
         use_graph_rewrite=common_flags.use_graph_rewrite())
     return squad_model, core_model
-  # The original BERT model does not scale the loss by
-  # 1/num_replicas_in_sync. It could be an accident. So, in order to use
-  # the same hyper parameter, we do the same thing here by keeping each
-  # replica loss as it is.
-  loss_fn = get_loss_fn(
-      loss_factor=1.0 /
-      strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0)
   # If explicit_allreduce = True, apply_gradients() no longer implicitly
   # allreduce gradients, users manually allreduce gradient and pass the
   # allreduced grads_and_vars to apply_gradients(). clip_by_global_norm will be
@@ -269,7 +258,7 @@ def train_squad(strategy,
   model_training_utils.run_customized_training_loop(
       strategy=strategy,
       model_fn=_get_squad_model,
-      loss_fn=loss_fn,
+      loss_fn=get_loss_fn(),
       model_dir=FLAGS.model_dir,
       steps_per_epoch=steps_per_epoch,
       steps_per_loop=FLAGS.steps_per_loop,
......