Commit e569c04c authored by A. Unique TensorFlower

Change the ELECTRA pretrainer output to a dict so that it is compatible with the downstream ELECTRA task.

Contributed by mickeystroller

PiperOrigin-RevId: 318909403
parent 99c9752e
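For context, here is a minimal sketch (not part of this commit) of how calling code consumes the new dict-valued output; `pretrainer` and `inputs` are assumed names and the model construction is omitted. The key names match the diff below.

```python
# Sketch only: `pretrainer` is assumed to be a built ElectraPretrainer and
# `inputs` the usual dict with input_word_ids / input_mask / input_type_ids
# plus the masked-LM fields.
outputs = pretrainer(inputs)

lm_logits = outputs['lm_outputs']         # [batch_size, num_token_predictions, vocab_size]
cls_logits = outputs['sentence_outputs']  # [batch_size, num_classes]
disc_logits = outputs['disc_logits']      # [batch_size, sequence_length]
disc_labels = outputs['disc_label']       # [batch_size, sequence_length]
```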
@@ -116,6 +116,22 @@ class ElectraPretrainer(tf.keras.Model):
         units=1, kernel_initializer=mlm_initializer)

   def call(self, inputs):
+    """ELECTRA forward pass.
+
+    Args:
+      inputs: A dict of all inputs, same as the standard BERT model.
+
+    Returns:
+      outputs: A dict of pretrainer model outputs, including
+        (1) lm_outputs: a [batch_size, num_token_predictions, vocab_size] tensor
+            indicating logits on masked positions.
+        (2) sentence_outputs: a [batch_size, num_classes] tensor indicating
+            logits for nsp task.
+        (3) disc_logits: a [batch_size, sequence_length] tensor indicating
+            logits for discriminator replaced token detection task.
+        (4) disc_label: a [batch_size, sequence_length] tensor indicating
+            target labels for discriminator replaced token detection task.
+    """
     input_word_ids = inputs['input_word_ids']
     input_mask = inputs['input_mask']
     input_type_ids = inputs['input_type_ids']
@@ -152,7 +168,14 @@ class ElectraPretrainer(tf.keras.Model):
     disc_logits = self.discriminator_head(disc_sequence_output)
     disc_logits = tf.squeeze(disc_logits, axis=-1)

-    return lm_outputs, sentence_outputs, disc_logits, disc_label
+    outputs = {
+        'lm_outputs': lm_outputs,
+        'sentence_outputs': sentence_outputs,
+        'disc_logits': disc_logits,
+        'disc_label': disc_label,
+    }
+    return outputs

   def _get_fake_data(self, inputs, mlm_logits, duplicate=True):
     """Generate corrupted data for discriminator.
......
@@ -69,7 +69,11 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
     }

     # Invoke the trainer model on the inputs. This causes the layer to be built.
-    lm_outs, cls_outs, disc_logits, disc_label = eletrca_trainer_model(inputs)
+    outputs = eletrca_trainer_model(inputs)
+    lm_outs = outputs['lm_outputs']
+    cls_outs = outputs['sentence_outputs']
+    disc_logits = outputs['disc_logits']
+    disc_label = outputs['disc_label']

     # Validate that the outputs are of the expected shape.
     expected_lm_shape = [None, num_token_predictions, vocab_size]
@@ -117,7 +121,7 @@ class ElectraPretrainerTest(keras_parameterized.TestCase):
     # Invoke the trainer model on the tensors. In Eager mode, this does the
     # actual calculation. (We can't validate the outputs, since the network is
     # too complex: this simply ensures we're not hitting runtime errors.)
-    _, _, _, _ = eletrca_trainer_model(inputs)
+    _ = eletrca_trainer_model(inputs)

   def test_serialize_deserialize(self):
     """Validate that the ELECTRA trainer can be serialized and deserialized."""
......
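Returning a dict also lets a downstream ELECTRA task build its losses by key instead of by tuple position. Below is a hedged sketch of that idea (not part of this commit); the helper name, the omitted per-position masking, and the discriminator weight are simplified assumptions (the ELECTRA paper weights the discriminator loss by 50).

```python
import tensorflow as tf


def electra_pretrain_loss(outputs, masked_lm_ids, disc_weight=50.0):
  """Sketch: combine generator and discriminator losses from the output dict.

  Assumes `masked_lm_ids` holds target token ids for the masked positions;
  padding masks and per-example weights are omitted for brevity.
  """
  # Generator (masked LM) loss over the predicted positions.
  mlm_loss = tf.reduce_mean(
      tf.keras.losses.sparse_categorical_crossentropy(
          masked_lm_ids, outputs['lm_outputs'], from_logits=True))
  # Discriminator replaced-token-detection loss over the full sequence.
  rtd_loss = tf.reduce_mean(
      tf.keras.losses.binary_crossentropy(
          tf.cast(outputs['disc_label'], tf.float32),
          outputs['disc_logits'],
          from_logits=True))
  return mlm_loss + disc_weight * rtd_loss
```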