Commit 7cda51fa authored by Le Hou, committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 339901099
parent e792d861
...@@ -36,6 +36,9 @@ class MaskedLMConfig(cfg.TaskConfig): ...@@ -36,6 +36,9 @@ class MaskedLMConfig(cfg.TaskConfig):
bert.ClsHeadConfig( bert.ClsHeadConfig(
inner_dim=768, num_classes=2, dropout_rate=0.1, name='next_sentence') inner_dim=768, num_classes=2, dropout_rate=0.1, name='next_sentence')
]) ])
# TODO(b/154564893): Mathematically, scale_loss should be True.
# However, it works better with scale_loss being False.
scale_loss: bool = False
train_data: cfg.DataConfig = cfg.DataConfig() train_data: cfg.DataConfig = cfg.DataConfig()
validation_data: cfg.DataConfig = cfg.DataConfig() validation_data: cfg.DataConfig = cfg.DataConfig()
...@@ -161,12 +164,15 @@ class MaskedLMTask(base_task.Task): ...@@ -161,12 +164,15 @@ class MaskedLMTask(base_task.Task):
model_outputs=outputs, model_outputs=outputs,
metrics=metrics, metrics=metrics,
aux_losses=model.losses) aux_losses=model.losses)
# Scales loss as the default gradients allreduce performs sum inside the if self.task_config.scale_loss:
# optimizer. # Scales loss as the default gradients allreduce performs sum inside the
# TODO(b/154564893): enable loss scaling. # optimizer.
# scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
tvars = model.trainable_variables tvars = model.trainable_variables
grads = tape.gradient(loss, tvars) if self.task_config.scale_loss:
grads = tape.gradient(scaled_loss, tvars)
else:
grads = tape.gradient(loss, tvars)
optimizer.apply_gradients(list(zip(grads, tvars))) optimizer.apply_gradients(list(zip(grads, tvars)))
self.process_metrics(metrics, inputs, outputs) self.process_metrics(metrics, inputs, outputs)
return {self.loss: loss} return {self.loss: loss}
......
...@@ -28,6 +28,7 @@ class MLMTaskTest(tf.test.TestCase): ...@@ -28,6 +28,7 @@ class MLMTaskTest(tf.test.TestCase):
def test_task(self): def test_task(self):
config = masked_lm.MaskedLMConfig( config = masked_lm.MaskedLMConfig(
init_checkpoint=self.get_temp_dir(), init_checkpoint=self.get_temp_dir(),
scale_loss=True,
model=bert.PretrainerConfig( model=bert.PretrainerConfig(
encoder=encoders.EncoderConfig( encoder=encoders.EncoderConfig(
bert=encoders.BertEncoderConfig(vocab_size=30522, bert=encoders.BertEncoderConfig(vocab_size=30522,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment