Commit 8624cb23 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 362613992
parent d48574cb
@@ -42,6 +42,9 @@ class BertTokenClassifier(tf.keras.Model):
       Defaults to a Glorot uniform initializer.
     output: The output style for this network. Can be either `logits` or
       `predictions`.
+    dropout_rate: The dropout probability of the token classification head.
+    output_encoder_outputs: Whether to include intermediate sequence output
+      in the final output.
   """

   def __init__(self,
@@ -50,6 +53,7 @@ class BertTokenClassifier(tf.keras.Model):
                initializer='glorot_uniform',
                output='logits',
                dropout_rate=0.1,
+               output_encoder_outputs=False,
                **kwargs):

     # We want to use the inputs of the passed network as the inputs to this
@@ -74,14 +78,19 @@ class BertTokenClassifier(tf.keras.Model):
         name='predictions/transform/logits')
     logits = classifier(sequence_output)
     if output == 'logits':
-      output_tensors = logits
+      output_tensors = {'logits': logits}
     elif output == 'predictions':
-      output_tensors = tf.keras.layers.Activation(tf.nn.log_softmax)(logits)
+      output_tensors = {
+          'predictions': tf.keras.layers.Activation(tf.nn.log_softmax)(logits)
+      }
     else:
       raise ValueError(
           ('Unknown `output` value "%s". `output` can be either "logits" or '
            '"predictions"') % output)
+    if output_encoder_outputs:
+      output_tensors['encoder_outputs'] = sequence_output

     # b/164516224
     # Once we've created the network using the Functional API, we call
     # super().__init__ as though we were invoking the Functional API Model
@@ -98,6 +107,7 @@ class BertTokenClassifier(tf.keras.Model):
         'num_classes': num_classes,
         'initializer': initializer,
         'output': output,
+        'output_encoder_outputs': output_encoder_outputs
     }

     # We are storing the config dict as a namedtuple here to ensure checkpoint
......
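For illustration only, not part of this commit: a minimal sketch of how the classifier's new dict-style outputs might be consumed, assuming only the constructor arguments shown in the hunks above; the sizes and input values below are arbitrary.

import tensorflow as tf

from official.nlp.modeling import networks
from official.nlp.modeling.models import bert_token_classifier

# Small encoder purely for illustration; argument values are arbitrary.
encoder = networks.BertEncoder(
    vocab_size=100,
    num_layers=2,
    hidden_size=768,
    max_sequence_length=128,
    dict_outputs=True)
model = bert_token_classifier.BertTokenClassifier(
    encoder,
    num_classes=3,
    output='logits',
    output_encoder_outputs=True)

# A batch of 2 sequences of length 32 (anything up to max_sequence_length).
word_ids = tf.ones((2, 32), dtype=tf.int32)
mask = tf.ones((2, 32), dtype=tf.int32)
type_ids = tf.zeros((2, 32), dtype=tf.int32)

outputs = model([word_ids, mask, type_ids])
logits = outputs['logits']                    # shape [2, 32, 3]
encoder_outputs = outputs['encoder_outputs']  # shape [2, 32, 768]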
@@ -27,22 +27,26 @@ from official.nlp.modeling.models import bert_token_classifier
 @keras_parameterized.run_all_keras_modes
 class BertTokenClassifierTest(keras_parameterized.TestCase):

-  @parameterized.parameters(True, False)
-  def test_bert_trainer(self, dict_outputs):
+  @parameterized.parameters((True, True), (False, False))
+  def test_bert_trainer(self, dict_outputs, output_encoder_outputs):
     """Validate that the Keras object can be created."""
     # Build a transformer network to use within the BERT trainer.
     vocab_size = 100
     sequence_length = 512
+    hidden_size = 768
     test_network = networks.BertEncoder(
         vocab_size=vocab_size,
         num_layers=2,
         max_sequence_length=sequence_length,
-        dict_outputs=dict_outputs)
+        dict_outputs=dict_outputs,
+        hidden_size=hidden_size)

     # Create a BERT trainer with the created network.
     num_classes = 3
     bert_trainer_model = bert_token_classifier.BertTokenClassifier(
-        test_network, num_classes=num_classes)
+        test_network,
+        num_classes=num_classes,
+        output_encoder_outputs=output_encoder_outputs)

     # Create a set of 2-dimensional inputs (the first dimension is implicit).
     word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
@@ -50,12 +54,18 @@ class BertTokenClassifierTest(keras_parameterized.TestCase):
     type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)

     # Invoke the trainer model on the inputs. This causes the layer to be built.
-    sequence_outs = bert_trainer_model([word_ids, mask, type_ids])
+    outputs = bert_trainer_model([word_ids, mask, type_ids])
+    if output_encoder_outputs:
+      logits = outputs['logits']
+      encoder_outputs = outputs['encoder_outputs']
+      self.assertAllEqual(encoder_outputs.shape.as_list(),
+                          [None, sequence_length, hidden_size])
+    else:
+      logits = outputs['logits']

     # Validate that the outputs are of the expected shape.
     expected_classification_shape = [None, sequence_length, num_classes]
-    self.assertAllEqual(expected_classification_shape,
-                        sequence_outs.shape.as_list())
+    self.assertAllEqual(expected_classification_shape, logits.shape.as_list())

   def test_bert_trainer_tensor_call(self):
     """Validate that the Keras object can be invoked."""
......
@@ -160,6 +160,8 @@ class MobileBertEncoderTest(parameterized.TestCase, tf.test.TestCase):
     mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
     prediction = classifier([word_ids, mask, type_ids])
+    if task == models.BertTokenClassifier:
+      prediction = prediction['logits']
     self.assertAllEqual(prediction.shape.as_list(), prediction_shape)
......
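Because BertTokenClassifier now returns a dict rather than a bare tensor, call sites that handle several head types (like the test above) have to select the logits explicitly. A hypothetical helper, not part of this change, that accepts either output style:

def get_logits(outputs):
  """Returns the logits whether a head emitted a dict or a single tensor."""
  if isinstance(outputs, dict):
    return outputs['logits']
  return outputs

# e.g. prediction = get_logits(classifier([word_ids, mask, type_ids]))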
@@ -98,13 +98,14 @@ class TaggingTask(base_task.Task):
         initializer=tf.keras.initializers.TruncatedNormal(
             stddev=self.task_config.model.head_initializer_range),
         dropout_rate=self.task_config.model.head_dropout,
-        output='logits')
+        output='logits',
+        output_encoder_outputs=True)

   def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
-    model_outputs = tf.cast(model_outputs, tf.float32)
+    logits = tf.cast(model_outputs['logits'], tf.float32)
     masked_labels, masked_weights = _masked_labels_and_weights(labels)
     loss = tf.keras.losses.sparse_categorical_crossentropy(
-        masked_labels, model_outputs, from_logits=True)
+        masked_labels, logits, from_logits=True)
     numerator_loss = tf.reduce_sum(loss * masked_weights)
     denominator_loss = tf.reduce_sum(masked_weights)
     loss = tf.math.divide_no_nan(numerator_loss, denominator_loss)
@@ -139,7 +140,7 @@ class TaggingTask(base_task.Task):
   def inference_step(self, inputs, model: tf.keras.Model):
     """Performs the forward step."""
-    logits = model(inputs, training=False)
+    logits = model(inputs, training=False)['logits']
     return {'logits': logits,
             'predict_ids': tf.argmax(logits, axis=-1, output_type=tf.int32)}
@@ -156,7 +157,7 @@ class TaggingTask(base_task.Task):
     """
     features, labels = inputs
     outputs = self.inference_step(features, model)
-    loss = self.build_losses(labels=labels, model_outputs=outputs['logits'])
+    loss = self.build_losses(labels=labels, model_outputs=outputs)

     # Negative label ids are padding labels which should be ignored.
     real_label_index = tf.where(tf.greater_equal(labels, 0))
......
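A self-contained sketch of the tagging loss after this change, with a hypothetical stand-in for `_masked_labels_and_weights` (its real implementation is not shown in this diff); negative label ids are treated as padding and excluded from the loss, matching the comment in the hunk above.

import tensorflow as tf

def _masked_labels_and_weights_sketch(labels):
  # Hypothetical stand-in: negative label ids mark padding tokens.
  weights = tf.cast(labels >= 0, tf.float32)
  masked_labels = tf.where(labels >= 0, labels, tf.zeros_like(labels))
  return masked_labels, weights

def tagging_loss(labels, model_outputs):
  # model_outputs is the dict produced by BertTokenClassifier.
  logits = tf.cast(model_outputs['logits'], tf.float32)
  masked_labels, masked_weights = _masked_labels_and_weights_sketch(labels)
  per_token_loss = tf.keras.losses.sparse_categorical_crossentropy(
      masked_labels, logits, from_logits=True)
  numerator = tf.reduce_sum(per_token_loss * masked_weights)
  denominator = tf.reduce_sum(masked_weights)
  # Average only over real (non-padding) tokens; 0 if everything is padding.
  return tf.math.divide_no_nan(numerator, denominator)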