Commit 9fb74a40 authored by Frederick Liu, committed by A. Unique TensorFlower

Avoid computing the MLM output if inputs do not have masked_lm_positions.

PiperOrigin-RevId: 351275455
parent ec39ee1d
@@ -244,9 +244,11 @@ class BertPretrainerV2(tf.keras.Model):
       raise ValueError('encoder_network\'s output should be either a list '
                        'or a dict, but got %s' % encoder_network_outputs)
     sequence_output = outputs['sequence_output']
-    masked_lm_positions = inputs['masked_lm_positions']
-    outputs['mlm_logits'] = self.masked_lm(
-        sequence_output, masked_positions=masked_lm_positions)
+    # Inference may not have masked_lm_positions and mlm_logits is not needed.
+    if 'masked_lm_positions' in inputs:
+      masked_lm_positions = inputs['masked_lm_positions']
+      outputs['mlm_logits'] = self.masked_lm(
+          sequence_output, masked_positions=masked_lm_positions)
     for cls_head in self.classification_heads:
       cls_outputs = cls_head(sequence_output)
       if isinstance(cls_outputs, dict):
...
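The effect of the new guard can be illustrated with a short, standalone sketch (not part of the commit). It assumes the official.nlp.modeling package layout of the TF Model Garden; the encoder sizes and the names encoder, pretrainer, and seq_len are illustrative only:

import tensorflow as tf
from official.nlp.modeling import models, networks

# Small, illustrative encoder configuration (values are arbitrary).
encoder = networks.BertEncoder(
    vocab_size=100, num_layers=2, hidden_size=32,
    num_attention_heads=2, intermediate_size=64)
pretrainer = models.BertPretrainerV2(encoder_network=encoder)

seq_len = 16
# Inference-style inputs: no 'masked_lm_positions' key.
inference_inputs = dict(
    input_word_ids=tf.zeros((1, seq_len), dtype=tf.int32),
    input_mask=tf.ones((1, seq_len), dtype=tf.int32),
    input_type_ids=tf.zeros((1, seq_len), dtype=tf.int32))
outputs = pretrainer(inference_inputs)
assert 'mlm_logits' not in outputs  # MLM head is skipped entirely.

# Training-style inputs: adding 'masked_lm_positions' re-enables the MLM head.
train_inputs = dict(
    inference_inputs,
    masked_lm_positions=tf.zeros((1, 4), dtype=tf.int32))
outputs = pretrainer(train_inputs)
assert 'mlm_logits' in outputs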
@@ -117,9 +117,10 @@ class BertPretrainerV2Test(keras_parameterized.TestCase):
       (False, True),
       (False, True),
       (False, True),
+      (False, True),
   ))
   def test_bert_pretrainerv2(self, dict_outputs, return_all_encoder_outputs,
-                             use_customized_masked_lm):
+                             use_customized_masked_lm, has_masked_lm_positions):
     """Validate that the Keras object can be created."""
     # Build a transformer network to use within the BERT trainer.
     vocab_size = 100
...@@ -148,27 +149,27 @@ class BertPretrainerV2Test(keras_parameterized.TestCase): ...@@ -148,27 +149,27 @@ class BertPretrainerV2Test(keras_parameterized.TestCase):
inputs = dict( inputs = dict(
input_word_ids=tf.keras.Input(shape=(sequence_length,), dtype=tf.int32), input_word_ids=tf.keras.Input(shape=(sequence_length,), dtype=tf.int32),
input_mask=tf.keras.Input(shape=(sequence_length,), dtype=tf.int32), input_mask=tf.keras.Input(shape=(sequence_length,), dtype=tf.int32),
input_type_ids=tf.keras.Input(shape=(sequence_length,), dtype=tf.int32), input_type_ids=tf.keras.Input(shape=(sequence_length,), dtype=tf.int32))
masked_lm_positions=tf.keras.Input( if has_masked_lm_positions:
shape=(num_token_predictions,), dtype=tf.int32)) inputs['masked_lm_positions'] = tf.keras.Input(
shape=(num_token_predictions,), dtype=tf.int32)
# Invoke the trainer model on the inputs. This causes the layer to be built. # Invoke the trainer model on the inputs. This causes the layer to be built.
outputs = bert_trainer_model(inputs) outputs = bert_trainer_model(inputs)
has_encoder_outputs = dict_outputs or return_all_encoder_outputs has_encoder_outputs = dict_outputs or return_all_encoder_outputs
expected_keys = ['sequence_output', 'pooled_output']
if has_encoder_outputs: if has_encoder_outputs:
self.assertSameElements( expected_keys.append('encoder_outputs')
outputs.keys(), if has_masked_lm_positions:
['sequence_output', 'pooled_output', 'mlm_logits', 'encoder_outputs']) expected_keys.append('mlm_logits')
self.assertLen(outputs['encoder_outputs'], num_layers)
else:
self.assertSameElements(
outputs.keys(), ['sequence_output', 'pooled_output', 'mlm_logits'])
self.assertSameElements(outputs.keys(), expected_keys)
# Validate that the outputs are of the expected shape. # Validate that the outputs are of the expected shape.
expected_lm_shape = [None, num_token_predictions, vocab_size] expected_lm_shape = [None, num_token_predictions, vocab_size]
self.assertAllEqual(expected_lm_shape, if has_masked_lm_positions:
outputs['mlm_logits'].shape.as_list()) self.assertAllEqual(expected_lm_shape,
outputs['mlm_logits'].shape.as_list())
expected_sequence_output_shape = [None, sequence_length, hidden_size] expected_sequence_output_shape = [None, sequence_length, hidden_size]
self.assertAllEqual(expected_sequence_output_shape, self.assertAllEqual(expected_sequence_output_shape,
......
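The extra (False, True) row in the test's parameter list adds the new has_masked_lm_positions axis. Assuming the decorator expands these rows itertools.product-style (the decorator itself lies above the shown hunk), the test now covers every combination of the four booleans; a quick sketch of the count:

import itertools

# Four boolean axes -> 2**4 = 16 parameter tuples instead of the previous 8.
combos = list(itertools.product((False, True), repeat=4))
assert len(combos) == 16
# Each tuple feeds the test arguments in order, e.g. (False, True, False, True)
# -> dict_outputs=False, return_all_encoder_outputs=True,
#    use_customized_masked_lm=False, has_masked_lm_positions=True.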