Commit bf6d6f6f authored by Hongkun Yu, committed by A. Unique TensorFlower

Adds scale_loss to run_customized_training_loop, which should be the correct treatment.

PiperOrigin-RevId: 301441181
parent 4f58b1f7
@@ -102,6 +102,7 @@ def run_customized_training_loop(
     strategy=None,
     model_fn=None,
     loss_fn=None,
+    scale_loss=True,
     model_dir=None,
     train_input_fn=None,
     steps_per_epoch=None,
@@ -129,6 +130,8 @@ def run_customized_training_loop(
       to be used for initial checkpoint -- if provided.
     loss_fn: Function with signature func(labels, logits) and returns a loss
       tensor.
+    scale_loss: Whether to divide the raw loss by number of replicas before
+      gradients calculation.
     model_dir: Model directory used during training for restoring/saving model
       weights.
     train_input_fn: Function that returns a tf.data.Dataset used for training.
@@ -284,6 +287,12 @@ def run_customized_training_loop(
       with tf.GradientTape() as tape:
         model_outputs = model(inputs, training=True)
         loss = loss_fn(labels, model_outputs)
+        # Raw loss is used for reporting in metrics/logs.
+        raw_loss = loss
+        if scale_loss:
+          # Scales down the loss for gradients to be invariant from replicas.
+          loss = loss / strategy.num_replicas_in_sync
+
       if explicit_allreduce:
         grad_utils.minimize_using_explicit_allreduce(tape, optimizer, loss,
                                                      training_vars,
@@ -300,7 +309,7 @@ def run_customized_training_loop(
         grads = tape.gradient(loss, training_vars)
         optimizer.apply_gradients(zip(grads, training_vars))
       # For reporting, the metric takes the mean of losses.
-      train_loss_metric.update_state(loss)
+      train_loss_metric.update_state(raw_loss)
       for metric in train_metrics:
         metric.update_state(labels, model_outputs)
...
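Editorial note (not part of the commit): under a synchronous strategy such as MirroredStrategy, gradients are summed across replicas when `apply_gradients` runs, so dividing each replica's loss by `strategy.num_replicas_in_sync` makes the summed gradients equal to the gradient of the mean loss over the global batch, while the unscaled `raw_loss` is what feeds the reporting metric. A minimal, self-contained sketch of that pattern, assuming the TF 2.x behavior of summing gradients across replicas; the toy model, optimizer, and data are placeholders, not the repo's code:

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
  # Toy stand-ins for the real BERT model, optimizer, and metric.
  model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
  model.build(input_shape=(None, 4))
  optimizer = tf.keras.optimizers.SGD(0.1)
  train_loss_metric = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)

per_example_loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.NONE)


@tf.function
def train_step(dist_inputs):

  def step_fn(inputs):
    features, labels = inputs
    with tf.GradientTape() as tape:
      logits = model(features, training=True)
      loss = tf.reduce_mean(per_example_loss_fn(labels, logits))
      # Raw (unscaled) loss is what the reporting metric should see.
      raw_loss = loss
      # Dividing by the replica count keeps the replica-summed gradients
      # equal to the gradient of the mean loss over the global batch.
      loss = loss / strategy.num_replicas_in_sync
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    train_loss_metric.update_state(raw_loss)

  strategy.run(step_fn, args=(dist_inputs,))


# Tiny driver with made-up data, just to exercise the step.
features = tf.random.uniform([8, 4])
labels = tf.random.uniform([8], maxval=2, dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(4)
for batch in strategy.experimental_distribute_dataset(dataset):
  train_step(batch)
print(float(train_loss_metric.result()))
```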
@@ -61,7 +61,7 @@ common_flags.define_common_bert_flags()
 FLAGS = flags.FLAGS


-def get_loss_fn(num_classes, loss_factor=1.0):
+def get_loss_fn(num_classes):
   """Gets the classification loss function."""

   def classification_loss_fn(labels, logits):
@@ -72,9 +72,7 @@ def get_loss_fn(num_classes, loss_factor=1.0):
         tf.cast(labels, dtype=tf.int32), depth=num_classes, dtype=tf.float32)
     per_example_loss = -tf.reduce_sum(
         tf.cast(one_hot_labels, dtype=tf.float32) * log_probs, axis=-1)
-    loss = tf.reduce_mean(per_example_loss)
-    loss *= loss_factor
-    return loss
+    return tf.reduce_mean(per_example_loss)

   return classification_loss_fn
@@ -135,17 +133,7 @@ def run_bert_classifier(strategy,
         use_graph_rewrite=common_flags.use_graph_rewrite())
     return classifier_model, core_model

-  # During distributed training, loss used for gradient computation is
-  # summed over from all replicas. When Keras compile/fit() API is used,
-  # the fit() API internally normalizes the loss by dividing the loss by
-  # the number of replicas used for computation. However, when custom
-  # training loop is used this is not done automatically and should be
-  # done manually by the end user.
-  loss_multiplier = 1.0
-  if FLAGS.scale_loss and not use_keras_compile_fit:
-    loss_multiplier = 1.0 / strategy.num_replicas_in_sync
-
-  loss_fn = get_loss_fn(num_classes, loss_factor=loss_multiplier)
+  loss_fn = get_loss_fn(num_classes)

   # Defines evaluation metrics function, which will create metrics in the
   # correct device and strategy scope.
...
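The comment removed in this file stated the rationale: per-replica gradients are summed, and Keras compile/fit divides the loss by the replica count while a custom loop must do so itself. With that division now handled by run_customized_training_loop via scale_loss, the classifier's loss function simply returns the per-example mean. A standalone arithmetic check of the underlying equivalence, illustrative only, with plain tensors and no distribution strategy:

```python
import numpy as np
import tensorflow as tf

num_replicas = 2
w = tf.Variable(1.0)
x_global = tf.constant([1.0, 2.0, 3.0, 4.0])  # the "global batch"

# Gradient of the mean loss over the whole global batch.
with tf.GradientTape() as tape:
  global_loss = tf.reduce_mean(tf.square(w * x_global))
grad_of_global_mean = tape.gradient(global_loss, w)

# Sum of per-replica gradients when each replica divides its mean loss by
# the replica count (what scale_loss=True now does in the training loop).
summed_replica_grads = tf.constant(0.0)
for shard in tf.split(x_global, num_replicas):  # equal-size replica shards
  with tf.GradientTape() as tape:
    replica_loss = tf.reduce_mean(tf.square(w * shard)) / num_replicas
  summed_replica_grads += tape.gradient(replica_loss, w)

np.testing.assert_allclose(
    grad_of_global_mean.numpy(), summed_replica_grads.numpy(), rtol=1e-6)
```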
@@ -74,11 +74,11 @@ def get_pretrain_dataset_fn(input_file_pattern, seq_length,
   return _dataset_fn


-def get_loss_fn(loss_factor=1.0):
+def get_loss_fn():
   """Returns loss function for BERT pretraining."""

   def _bert_pretrain_loss_fn(unused_labels, losses, **unused_args):
-    return tf.reduce_mean(losses) * loss_factor
+    return tf.reduce_mean(losses)

   return _bert_pretrain_loss_fn
@@ -116,9 +116,8 @@ def run_customized_training(strategy,
   trained_model = model_training_utils.run_customized_training_loop(
       strategy=strategy,
       model_fn=_get_pretrain_model,
-      loss_fn=get_loss_fn(
-          loss_factor=1.0 /
-          strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0),
+      loss_fn=get_loss_fn(),
+      scale_loss=FLAGS.scale_loss,
       model_dir=model_dir,
       train_input_fn=train_input_fn,
       steps_per_epoch=steps_per_epoch,
...
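In this file the 1/num_replicas factor moves from the loss function (the old loss_factor argument) to the training loop (scale_loss=FLAGS.scale_loss). The two wirings compute the same scaled value; a toy sketch of that equivalence with made-up numbers, where `losses` stands in for the per-example pretraining losses:

```python
import tensorflow as tf

num_replicas_in_sync = 8
losses = tf.constant([2.0, 4.0, 6.0])  # made-up per-example losses

# Old wiring: the scale was baked into the loss function via loss_factor.
old_loss_fn = lambda unused_labels, l: (
    tf.reduce_mean(l) * (1.0 / num_replicas_in_sync))

# New wiring: the loss function returns the raw mean; the training loop
# divides it when scale_loss=True.
new_loss_fn = lambda unused_labels, l: tf.reduce_mean(l)

old_value = old_loss_fn(None, losses)
new_value = new_loss_fn(None, losses) / num_replicas_in_sync

tf.debugging.assert_near(old_value, new_value)
print(float(old_value), float(new_value))
```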
@@ -90,8 +90,7 @@ FLAGS = flags.FLAGS
 def squad_loss_fn(start_positions,
                   end_positions,
                   start_logits,
-                  end_logits,
-                  loss_factor=1.0):
+                  end_logits):
   """Returns sparse categorical crossentropy for start/end logits."""
   start_loss = tf.keras.losses.sparse_categorical_crossentropy(
       start_positions, start_logits, from_logits=True)
@@ -99,11 +98,10 @@ def squad_loss_fn(start_positions,
       end_positions, end_logits, from_logits=True)

   total_loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2
-  total_loss *= loss_factor

   return total_loss


-def get_loss_fn(loss_factor=1.0):
+def get_loss_fn():
   """Gets a loss function for squad task."""

   def _loss_fn(labels, model_outputs):
@@ -114,8 +112,7 @@ def get_loss_fn(loss_factor=1.0):
         start_positions,
         end_positions,
         start_logits,
-        end_logits,
-        loss_factor=loss_factor)
+        end_logits)

   return _loss_fn
@@ -249,14 +246,6 @@ def train_squad(strategy,
         use_graph_rewrite=common_flags.use_graph_rewrite())
     return squad_model, core_model

-  # The original BERT model does not scale the loss by
-  # 1/num_replicas_in_sync. It could be an accident. So, in order to use
-  # the same hyper parameter, we do the same thing here by keeping each
-  # replica loss as it is.
-  loss_fn = get_loss_fn(
-      loss_factor=1.0 /
-      strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0)
-
   # If explicit_allreduce = True, apply_gradients() no longer implicitly
   # allreduce gradients, users manually allreduce gradient and pass the
   # allreduced grads_and_vars to apply_gradients(). clip_by_global_norm will be
@@ -269,7 +258,7 @@ def train_squad(strategy,
   model_training_utils.run_customized_training_loop(
       strategy=strategy,
       model_fn=_get_squad_model,
-      loss_fn=loss_fn,
+      loss_fn=get_loss_fn(),
       model_dir=FLAGS.model_dir,
       steps_per_epoch=steps_per_epoch,
       steps_per_loop=FLAGS.steps_per_loop,
...