"vscode:/vscode.git/clone" did not exist on "610d2d039aceaf217c4ce1195c6d9958f1ca6a99"
Commit a009f4fb authored by Hongkun Yu, committed by A. Unique TensorFlower

Move the collection of trainable variables outside the training loop.

Add a flag to control loss scaling.

PiperOrigin-RevId: 267091566
parent a85c40e3
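For context on the loss-scaling half of this change: with a synchronous tf.distribute strategy, gradients computed on each replica are summed when they are applied, so dividing every per-replica loss by num_replicas_in_sync makes the aggregated update equal to the gradient of the global mean loss. A minimal arithmetic sketch (the replica count and loss values below are made up for illustration):

num_replicas_in_sync = 2
per_replica_losses = [0.8, 1.2]  # hypothetical per-replica mean losses

# Without scaling, summing per-replica gradients corresponds to
# differentiating the sum of the replica losses.
unscaled_total = sum(per_replica_losses)  # 2.0

# With scaling enabled, each replica divides its loss by the replica count,
# so the cross-replica sum equals the global mean loss.
scaled_total = sum(loss / num_replicas_in_sync for loss in per_replica_losses)  # 1.0

assert scaled_total == sum(per_replica_losses) / num_replicas_in_sync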
@@ -52,6 +52,10 @@ def define_common_bert_flags():
   flags.DEFINE_boolean(
       'run_eagerly', False,
       'Run the model op by op without building a model function.')
+  flags.DEFINE_boolean(
+      'scale_loss', False,
+      'Whether to divide the loss by number of replica inside the per-replica '
+      'loss function.')
   # Adds flags for mixed precision training.
   flags_core.define_performance(
...
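Since 'scale_loss' is defined with flags.DEFINE_boolean, it can be toggled on the command line as --scale_loss or --noscale_loss; the False default keeps the previous behaviour of leaving per-replica losses unscaled.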
@@ -231,6 +231,10 @@ def run_customized_training_loop(
   else:
     train_summary_writer = None
 
+  # De-dupes variables due to keras tracking issues.
+  training_vars = list({id(v): v for v in model.trainable_variables
+                       }.values())
+
   def _replicated_step(inputs):
     """Replicated training step."""
@@ -241,14 +245,12 @@ def run_customized_training_loop(
       if use_float16:
         scaled_loss = optimizer.get_scaled_loss(loss)
 
-    # De-dupes variables due to keras tracking issues.
-    tvars = list({id(v): v for v in model.trainable_variables}.values())
     if use_float16:
-      scaled_grads = tape.gradient(scaled_loss, tvars)
+      scaled_grads = tape.gradient(scaled_loss, training_vars)
       grads = optimizer.get_unscaled_gradients(scaled_grads)
     else:
-      grads = tape.gradient(loss, tvars)
-    optimizer.apply_gradients(zip(grads, tvars))
+      grads = tape.gradient(loss, training_vars)
+    optimizer.apply_gradients(zip(grads, training_vars))
 
     # For reporting, the metric takes the mean of losses.
     train_loss_metric.update_state(loss)
     for metric in train_metrics:
...
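Note that training_vars is now computed once when the training loop is set up rather than inside every _replicated_step call. The dict comprehension keyed by id() removes duplicate references that Keras object tracking can produce while preserving order; a standalone sketch of the idiom (the variables here are illustrative, not taken from the model):

import tensorflow as tf

w = tf.Variable(1.0, name='w')
b = tf.Variable(2.0, name='b')
tracked = [w, b, w]  # tracking may report the same variable object twice

# Keying by id() keeps one entry per variable object, in original order.
unique_vars = list({id(v): v for v in tracked}.values())
assert len(unique_vars) == 2
assert unique_vars[0] is w and unique_vars[1] is b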
@@ -61,7 +61,7 @@ common_flags.define_common_bert_flags()
 
 FLAGS = flags.FLAGS
 
 
-def get_loss_fn(num_classes, loss_scale=1.0):
+def get_loss_fn(num_classes, loss_factor=1.0):
   """Gets the classification loss function."""
 
   def classification_loss_fn(labels, logits):
@@ -73,7 +73,7 @@ def get_loss_fn(num_classes, loss_scale=1.0):
     per_example_loss = -tf.reduce_sum(
         tf.cast(one_hot_labels, dtype=tf.float32) * log_probs, axis=-1)
     loss = tf.reduce_mean(per_example_loss)
-    loss *= loss_scale
+    loss *= loss_factor
     return loss
 
   return classification_loss_fn
@@ -118,7 +118,10 @@ def run_customized_training(strategy,
         initial_lr, steps_per_epoch * epochs, warmup_steps)
     return classifier_model, core_model
 
-  loss_fn = get_loss_fn(num_classes, loss_scale=1.0)
+  loss_fn = get_loss_fn(
+      num_classes,
+      loss_factor=1.0 /
+      strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0)
 
   # Defines evaluation metrics function, which will create metrics in the
   # correct device and strategy scope.
...
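A reading note on the new get_loss_fn call: division binds tighter than the conditional expression in Python, so loss_factor evaluates to 1.0 / strategy.num_replicas_in_sync when FLAGS.scale_loss is set and to 1.0 otherwise. A small standalone check (the replica count below is made up):

num_replicas_in_sync = 8

scale_loss = True
factor = 1.0 / num_replicas_in_sync if scale_loss else 1.0
assert factor == 0.125  # (1.0 / 8) when scaling is enabled

scale_loss = False
factor = 1.0 / num_replicas_in_sync if scale_loss else 1.0
assert factor == 1.0  # unscaled when the flag is off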
@@ -94,11 +94,11 @@ def get_pretrain_input_data(input_file_pattern, seq_length,
   return _dataset_fn if use_dataset_fn else _dataset_fn()
 
 
-def get_loss_fn(loss_scale=1.0):
+def get_loss_fn(loss_factor=1.0):
   """Returns loss function for BERT pretraining."""
 
   def _bert_pretrain_loss_fn(unused_labels, losses, **unused_args):
-    return tf.keras.backend.mean(losses) * loss_scale
+    return tf.keras.backend.mean(losses) * loss_factor
 
   return _bert_pretrain_loss_fn
@@ -132,7 +132,9 @@ def run_customized_training(strategy,
   trained_model = model_training_utils.run_customized_training_loop(
       strategy=strategy,
       model_fn=_get_pretrain_model,
-      loss_fn=get_loss_fn(),
+      loss_fn=get_loss_fn(
+          loss_factor=1.0 /
+          strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0),
       model_dir=model_dir,
       train_input_fn=train_input_fn,
       steps_per_epoch=steps_per_epoch,
...
@@ -232,7 +232,9 @@ def train_squad(strategy,
   # 1/num_replicas_in_sync. It could be an accident. So, in order to use
   # the same hyper parameter, we do the same thing here by keeping each
   # replica loss as it is.
-  loss_fn = get_loss_fn(loss_factor=1.0)
+  loss_fn = get_loss_fn(
+      loss_factor=1.0 /
+      strategy.num_replicas_in_sync if FLAGS.scale_loss else 1.0)
 
   use_remote_tpu = (FLAGS.strategy_type == 'tpu' and FLAGS.tpu)
 
   model_training_utils.run_customized_training_loop(
...