Commit 8624cb23 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 362613992
parent d48574cb
@@ -42,6 +42,9 @@ class BertTokenClassifier(tf.keras.Model):
       Defaults to a Glorot uniform initializer.
     output: The output style for this network. Can be either `logits` or
       `predictions`.
+    dropout_rate: The dropout probability of the token classification head.
+    output_encoder_outputs: Whether to include intermediate sequence output
+      in the final output.
   """

   def __init__(self,
@@ -50,6 +53,7 @@ class BertTokenClassifier(tf.keras.Model):
                initializer='glorot_uniform',
                output='logits',
                dropout_rate=0.1,
+               output_encoder_outputs=False,
                **kwargs):
     # We want to use the inputs of the passed network as the inputs to this
@@ -74,14 +78,19 @@ class BertTokenClassifier(tf.keras.Model):
         name='predictions/transform/logits')
     logits = classifier(sequence_output)
     if output == 'logits':
-      output_tensors = logits
+      output_tensors = {'logits': logits}
     elif output == 'predictions':
-      output_tensors = tf.keras.layers.Activation(tf.nn.log_softmax)(logits)
+      output_tensors = {
+          'predictions': tf.keras.layers.Activation(tf.nn.log_softmax)(logits)
+      }
     else:
       raise ValueError(
           ('Unknown `output` value "%s". `output` can be either "logits" or '
            '"predictions"') % output)
+    if output_encoder_outputs:
+      output_tensors['encoder_outputs'] = sequence_output

     # b/164516224
     # Once we've created the network using the Functional API, we call
     # super().__init__ as though we were invoking the Functional API Model
@@ -98,6 +107,7 @@ class BertTokenClassifier(tf.keras.Model):
         'num_classes': num_classes,
         'initializer': initializer,
         'output': output,
+        'output_encoder_outputs': output_encoder_outputs
     }
     # We are storing the config dict as a namedtuple here to ensure checkpoint
...
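Taken together, the hunks above change `BertTokenClassifier` from returning a bare tensor to returning a dict of named tensors. Below is a minimal usage sketch of the new interface; it is not part of this commit, and the toy sizes (`hidden_size=128`, `num_attention_heads=2`, batch of 2, length 16) are chosen here for illustration only.

```python
import tensorflow as tf
from official.nlp.modeling import networks
from official.nlp.modeling.models import bert_token_classifier

# Build a small encoder and wrap it in the token classifier, opting in to
# the `output_encoder_outputs` flag added by this commit.
encoder = networks.BertEncoder(
    vocab_size=100, num_layers=2, hidden_size=128, num_attention_heads=2)
model = bert_token_classifier.BertTokenClassifier(
    encoder,
    num_classes=3,
    output='logits',
    output_encoder_outputs=True)

# After this change the model returns a dict rather than a single tensor.
word_ids = tf.ones((2, 16), dtype=tf.int32)
mask = tf.ones((2, 16), dtype=tf.int32)
type_ids = tf.zeros((2, 16), dtype=tf.int32)
outputs = model([word_ids, mask, type_ids])
logits = outputs['logits']                    # shape [2, 16, 3]
encoder_outputs = outputs['encoder_outputs']  # shape [2, 16, 128]
```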
@@ -27,22 +27,26 @@ from official.nlp.modeling.models import bert_token_classifier
 @keras_parameterized.run_all_keras_modes
 class BertTokenClassifierTest(keras_parameterized.TestCase):

-  @parameterized.parameters(True, False)
-  def test_bert_trainer(self, dict_outputs):
+  @parameterized.parameters((True, True), (False, False))
+  def test_bert_trainer(self, dict_outputs, output_encoder_outputs):
     """Validate that the Keras object can be created."""
     # Build a transformer network to use within the BERT trainer.
     vocab_size = 100
     sequence_length = 512
+    hidden_size = 768
     test_network = networks.BertEncoder(
         vocab_size=vocab_size,
         num_layers=2,
         max_sequence_length=sequence_length,
-        dict_outputs=dict_outputs)
+        dict_outputs=dict_outputs,
+        hidden_size=hidden_size)

     # Create a BERT trainer with the created network.
     num_classes = 3
     bert_trainer_model = bert_token_classifier.BertTokenClassifier(
-        test_network, num_classes=num_classes)
+        test_network,
+        num_classes=num_classes,
+        output_encoder_outputs=output_encoder_outputs)

     # Create a set of 2-dimensional inputs (the first dimension is implicit).
     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
@@ -50,12 +54,18 @@ class BertTokenClassifierTest(keras_parameterized.TestCase):
     type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)

     # Invoke the trainer model on the inputs. This causes the layer to be built.
-    sequence_outs = bert_trainer_model([word_ids, mask, type_ids])
+    outputs = bert_trainer_model([word_ids, mask, type_ids])
+    if output_encoder_outputs:
+      logits = outputs['logits']
+      encoder_outputs = outputs['encoder_outputs']
+      self.assertAllEqual(encoder_outputs.shape.as_list(),
+                          [None, sequence_length, hidden_size])
+    else:
+      logits = outputs['logits']

     # Validate that the outputs are of the expected shape.
     expected_classification_shape = [None, sequence_length, num_classes]
-    self.assertAllEqual(expected_classification_shape,
-                        sequence_outs.shape.as_list())
+    self.assertAllEqual(expected_classification_shape, logits.shape.as_list())

   def test_bert_trainer_tensor_call(self):
     """Validate that the Keras object can be invoked."""
...
@@ -160,6 +160,8 @@ class MobileBertEncoderTest(parameterized.TestCase, tf.test.TestCase):
     mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     prediction = classifier([word_ids, mask, type_ids])
+    if task == models.BertTokenClassifier:
+      prediction = prediction['logits']
     self.assertAllEqual(prediction.shape.as_list(), prediction_shape)
...
@@ -98,13 +98,14 @@ class TaggingTask(base_task.Task):
         initializer=tf.keras.initializers.TruncatedNormal(
             stddev=self.task_config.model.head_initializer_range),
         dropout_rate=self.task_config.model.head_dropout,
-        output='logits')
+        output='logits',
+        output_encoder_outputs=True)

   def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
-    model_outputs = tf.cast(model_outputs, tf.float32)
+    logits = tf.cast(model_outputs['logits'], tf.float32)
     masked_labels, masked_weights = _masked_labels_and_weights(labels)
     loss = tf.keras.losses.sparse_categorical_crossentropy(
-        masked_labels, model_outputs, from_logits=True)
+        masked_labels, logits, from_logits=True)
     numerator_loss = tf.reduce_sum(loss * masked_weights)
     denominator_loss = tf.reduce_sum(masked_weights)
     loss = tf.math.divide_no_nan(numerator_loss, denominator_loss)
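`_masked_labels_and_weights` is defined outside this diff. As a hypothetical sketch of the contract the loss above relies on (consistent with the comment later in this file that negative label ids are padding), it could look like this:

```python
import tensorflow as tf

def _masked_labels_and_weights(labels):
  """Hypothetical sketch only; the real helper is not shown in this diff.

  Negative label ids mark padding positions, which get zero loss weight.
  """
  # Give padding positions weight 0, and remap their label to the dummy
  # class 0 so sparse_categorical_crossentropy still sees valid indices.
  masked_weights = tf.cast(tf.greater_equal(labels, 0), tf.float32)
  masked_labels = tf.where(labels >= 0, labels, tf.zeros_like(labels))
  return masked_labels, masked_weights
```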
@@ -139,7 +140,7 @@ class TaggingTask(base_task.Task):
   def inference_step(self, inputs, model: tf.keras.Model):
     """Performs the forward step."""
-    logits = model(inputs, training=False)
+    logits = model(inputs, training=False)['logits']
     return {'logits': logits,
             'predict_ids': tf.argmax(logits, axis=-1, output_type=tf.int32)}
@@ -156,7 +157,7 @@ class TaggingTask(base_task.Task):
     """
     features, labels = inputs
     outputs = self.inference_step(features, model)
-    loss = self.build_losses(labels=labels, model_outputs=outputs['logits'])
+    loss = self.build_losses(labels=labels, model_outputs=outputs)

     # Negative label ids are padding labels which should be ignored.
     real_label_index = tf.where(tf.greater_equal(labels, 0))
...