Commit a009f4fb authored by Hongkun Yu, committed by A. Unique TensorFlower

Move the collection of trainable variables outside the per-step loop.

Add a flag to control loss scaling.

PiperOrigin-RevId: 267091566
parent a85c40e3
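
For context on the loss-scaling change: under tf.distribute.Strategy, each replica computes gradients from its own loss and the strategy sums them across replicas, so leaving the per-replica loss unscaled multiplies the effective update by the number of replicas. The sketch below shows the general pattern the new flag toggles; it is not the code in this commit, it assumes current TF 2.x APIs (strategy.run) and uses toy model and data names.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
SCALE_LOSS = True  # stands in for the new --scale_loss flag

with strategy.scope():
  # Toy model and optimizer; any Keras model works the same way here.
  model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
  optimizer = tf.keras.optimizers.SGD(0.1)

@tf.function
def train_step(features, labels):
  def step_fn(features, labels):
    with tf.GradientTape() as tape:
      preds = model(features, training=True)
      loss = tf.reduce_mean(tf.square(preds - labels))
      if SCALE_LOSS:
        # Each replica contributes loss / N, so the gradients summed across
        # replicas stay on the same scale as single-device training.
        loss = loss / strategy.num_replicas_in_sync
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

  per_replica_loss = strategy.run(step_fn, args=(features, labels))
  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)

# A real loop would feed per-replica batches via
# strategy.experimental_distribute_dataset; plain tensors keep the sketch short.
print(float(train_step(tf.random.normal([8, 4]), tf.random.normal([8, 1]))))

With SCALE_LOSS on, the summed gradients match the single-replica scale; with it off, the behaviour matches the original unscaled BERT loops touched below.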
@@ -52,6 +52,10 @@ def define_common_bert_flags():
   flags.DEFINE_boolean(
       'run_eagerly', False,
       'Run the model op by op without building a model function.')
+  flags.DEFINE_boolean(
+      'scale_loss', False,
+      'Whether to divide the loss by number of replica inside the per-replica '
+      'loss function.')
   # Adds flags for mixed precision training.
   flags_core.define_performance(
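
As a usage note, absl boolean flags defined this way can be toggled on the command line as --scale_loss or --noscale_loss. A self-contained sketch of defining and reading such a flag follows; it is a hypothetical standalone script, not a file in this commit.

from absl import app, flags

# In the repo the flag is registered by common_flags.define_common_bert_flags();
# this script registers it directly just to show the mechanics.
flags.DEFINE_boolean(
    'scale_loss', False,
    'Whether to divide the loss by the number of replicas inside the '
    'per-replica loss function.')
FLAGS = flags.FLAGS


def main(_):
  # Boolean absl flags accept --scale_loss and --noscale_loss on the CLI.
  print('scale_loss =', FLAGS.scale_loss)


if __name__ == '__main__':
  app.run(main)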
@@ -231,6 +231,10 @@ def run_customized_training_loop(
     else:
       train_summary_writer = None
 
+    # De-dupes variables due to keras tracking issues.
+    training_vars = list({id(v): v for v in model.trainable_variables
+                         }.values())
+
     def _replicated_step(inputs):
       """Replicated training step."""
@@ -241,14 +245,12 @@
         if use_float16:
           scaled_loss = optimizer.get_scaled_loss(loss)
 
-      # De-dupes variables due to keras tracking issues.
-      tvars = list({id(v): v for v in model.trainable_variables}.values())
       if use_float16:
-        scaled_grads = tape.gradient(scaled_loss, tvars)
+        scaled_grads = tape.gradient(scaled_loss, training_vars)
         grads = optimizer.get_unscaled_gradients(scaled_grads)
       else:
-        grads = tape.gradient(loss, tvars)
-      optimizer.apply_gradients(zip(grads, tvars))
+        grads = tape.gradient(loss, training_vars)
+      optimizer.apply_gradients(zip(grads, training_vars))
       # For reporting, the metric takes the mean of losses.
       train_loss_metric.update_state(loss)
       for metric in train_metrics:
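
Hoisting the de-duplication out of _replicated_step builds the training_vars list once instead of on every step, and both gradient paths then reuse the same list. The de-dup idiom itself is shown standalone below; the duplicate tracking is contrived for illustration.

import tensorflow as tf

v1 = tf.Variable(1.0, name='kernel')
v2 = tf.Variable(2.0, name='bias')
# Contrived duplicate standing in for a Keras tracking quirk that can yield
# the same variable object more than once.
tracked = [v1, v2, v1]

# A dict keyed by id() keeps one entry per underlying object and, in
# Python 3.7+, preserves the order of first appearance.
training_vars = list({id(v): v for v in tracked}.values())
assert len(training_vars) == 2 and training_vars[0] is v1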
@@ -61,7 +61,7 @@ common_flags.define_common_bert_flags()
 FLAGS = flags.FLAGS
 
 
-def get_loss_fn(num_classes, loss_scale=1.0):
+def get_loss_fn(num_classes, loss_factor=1.0):
   """Gets the classification loss function."""
 
   def classification_loss_fn(labels, logits):
@@ -73,7 +73,7 @@ def get_loss_fn(num_classes, loss_scale=1.0):
     per_example_loss = -tf.reduce_sum(
         tf.cast(one_hot_labels, dtype=tf.float32) * log_probs, axis=-1)
     loss = tf.reduce_mean(per_example_loss)
-    loss *= loss_scale
+    loss *= loss_factor
     return loss
 
   return classification_loss_fn
@@ -118,7 +118,10 @@ def run_customized_training(strategy,
         initial_lr, steps_per_epoch * epochs, warmup_steps)
     return classifier_model, core_model
 
-  loss_fn = get_loss_fn(num_classes, loss_scale=1.0)
+  loss_fn = get_loss_fn(
+      num_classes,
+      loss_factor=1.0 /
+      strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0)
 
   # Defines evaluation metrics function, which will create metrics in the
   # correct device and strategy scope.
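
A reading note on the new call sites: the conditional expression binds more loosely than the division, so loss_factor evaluates to (1.0 / strategy.num_replicas_in_sync) when --scale_loss is set and to 1.0 otherwise. A quick precedence check with a stand-in replica count:

num_replicas = 8  # stand-in for strategy.num_replicas_in_sync
for scale_loss in (True, False):
  # Parses as (1.0 / num_replicas) if scale_loss else 1.0
  loss_factor = 1.0 / num_replicas if scale_loss else 1.0
  print(scale_loss, loss_factor)  # True 0.125, then False 1.0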
@@ -94,11 +94,11 @@ def get_pretrain_input_data(input_file_pattern, seq_length,
   return _dataset_fn if use_dataset_fn else _dataset_fn()
 
 
-def get_loss_fn(loss_scale=1.0):
+def get_loss_fn(loss_factor=1.0):
   """Returns loss function for BERT pretraining."""
 
   def _bert_pretrain_loss_fn(unused_labels, losses, **unused_args):
-    return tf.keras.backend.mean(losses) * loss_scale
+    return tf.keras.backend.mean(losses) * loss_factor
 
   return _bert_pretrain_loss_fn
@@ -132,7 +132,9 @@ def run_customized_training(strategy,
   trained_model = model_training_utils.run_customized_training_loop(
       strategy=strategy,
       model_fn=_get_pretrain_model,
-      loss_fn=get_loss_fn(),
+      loss_fn=get_loss_fn(
+          loss_factor=1.0 /
+          strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0),
       model_dir=model_dir,
       train_input_fn=train_input_fn,
       steps_per_epoch=steps_per_epoch,
@@ -232,7 +232,9 @@ def train_squad(strategy,
   # 1/num_replicas_in_sync. It could be an accident. So, in order to use
   # the same hyper parameter, we do the same thing here by keeping each
   # replica loss as it is.
-  loss_fn = get_loss_fn(loss_factor=1.0)
+  loss_fn = get_loss_fn(
+      loss_factor=1.0 /
+      strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0)
   use_remote_tpu = (FLAGS.strategy_type == 'tpu' and FLAGS.tpu)
 
   model_training_utils.run_customized_training_loop(
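
The comment retained above explains the default: the original BERT implementation did not apply the 1/num_replicas_in_sync factor, so scaling is opt-in to keep the published hyperparameters valid. A toy calculation with made-up numbers shows how the summed cross-replica gradient differs with and without the factor.

num_replicas = 8
per_replica_grad = 0.5  # made-up gradient each replica would compute

unscaled = num_replicas * per_replica_grad                  # 4.0
scaled = num_replicas * (per_replica_grad / num_replicas)   # 0.5
print(unscaled, scaled)  # the unscaled update is num_replicas times larger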