Commit 7cda51fa authored by Le Hou's avatar Le Hou Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 339901099
parent e792d861
......@@ -36,6 +36,9 @@ class MaskedLMConfig(cfg.TaskConfig):
bert.ClsHeadConfig(
inner_dim=768, num_classes=2, dropout_rate=0.1, name='next_sentence')
])
# TODO(b/154564893): Mathematically, scale_loss should be True.
# However, it works better with scale_loss being False.
scale_loss: bool = False
train_data: cfg.DataConfig = cfg.DataConfig()
validation_data: cfg.DataConfig = cfg.DataConfig()
......@@ -161,12 +164,15 @@ class MaskedLMTask(base_task.Task):
model_outputs=outputs,
metrics=metrics,
aux_losses=model.losses)
# Scales loss as the default gradients allreduce performs sum inside the
# optimizer.
# TODO(b/154564893): enable loss scaling.
# scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
if self.task_config.scale_loss:
# Scales loss as the default gradients allreduce performs sum inside the
# optimizer.
scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
tvars = model.trainable_variables
grads = tape.gradient(loss, tvars)
if self.task_config.scale_loss:
grads = tape.gradient(scaled_loss, tvars)
else:
grads = tape.gradient(loss, tvars)
optimizer.apply_gradients(list(zip(grads, tvars)))
self.process_metrics(metrics, inputs, outputs)
return {self.loss: loss}
......
......@@ -28,6 +28,7 @@ class MLMTaskTest(tf.test.TestCase):
def test_task(self):
config = masked_lm.MaskedLMConfig(
init_checkpoint=self.get_temp_dir(),
scale_loss=True,
model=bert.PretrainerConfig(
encoder=encoders.EncoderConfig(
bert=encoders.BertEncoderConfig(vocab_size=30522,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment