"vscode:/vscode.git/clone" did not exist on "384e4471e98df3b782c1936a4da9ed3566a3f760"
Commit 25542676 authored by Jeremiah Harmsen's avatar Jeremiah Harmsen Committed by A. Unique TensorFlower
Browse files

Add network and BERT model to perform per-token classification (e.g., for...

Add network and BERT model to perform per-token classification (e.g., for named entity recognition tasks).

PiperOrigin-RevId: 311480326
parent ec3a4616
......@@ -8,6 +8,10 @@ models are intended as both convenience functions and canonical examples.
* [`BertClassifier`](bert_classifier.py) implements a simple classification
model containing a single classification head using the Classification network.
* [`BertTokenClassifier`](bert_token_classifier.py) implements a simple token
classification model containing a single classification head using the
TokenClassification network.
* [`BertSpanLabeler`](bert_span_labeler.py) implementats a simple single-span
start-end predictor (that is, a model that predicts two values: a start token
index and an end token index), suitable for SQuAD-style tasks.
......
......@@ -16,3 +16,4 @@
from official.nlp.modeling.models.bert_classifier import BertClassifier
from official.nlp.modeling.models.bert_pretrainer import BertPretrainer
from official.nlp.modeling.models.bert_span_labeler import BertSpanLabeler
from official.nlp.modeling.models.bert_token_classifier import BertTokenClassifier
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trainer network for BERT-style models."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import tensorflow as tf
from official.nlp.modeling import networks
@tf.keras.utils.register_keras_serializable(package='Text')
class BertTokenClassifier(tf.keras.Model):
"""Token classifier model based on a BERT-style transformer-based encoder.
This is an implementation of the network structure surrounding a transformer
encoder as described in "BERT: Pre-training of Deep Bidirectional Transformers
for Language Understanding" (https://arxiv.org/abs/1810.04805).
The BertTokenClassifier allows a user to pass in a transformer stack, and
instantiates a token classification network based on the passed `num_classes`
argument.
Arguments:
network: A transformer network. This network should output a sequence output
and a classification output. Furthermore, it should expose its embedding
table via a "get_embedding_table" method.
num_classes: Number of classes to predict from the classification network.
initializer: The initializer (if any) to use in the classification networks.
Defaults to a Glorot uniform initializer.
output: The output style for this network. Can be either 'logits' or
'predictions'.
"""
def __init__(self,
network,
num_classes,
initializer='glorot_uniform',
output='logits',
dropout_rate=0.1,
**kwargs):
self._self_setattr_tracking = False
self._config = {
'network': network,
'num_classes': num_classes,
'initializer': initializer,
'output': output,
}
# We want to use the inputs of the passed network as the inputs to this
# Model. To do this, we need to keep a handle to the network inputs for use
# when we construct the Model object at the end of init.
inputs = network.inputs
# Because we have a copy of inputs to create this Model object, we can
# invoke the Network object with its own input tensors to start the Model.
sequence_output, _ = network(inputs)
sequence_output = tf.keras.layers.Dropout(
rate=dropout_rate)(sequence_output)
self.classifier = networks.TokenClassification(
input_width=sequence_output.shape[-1],
num_classes=num_classes,
initializer=initializer,
output=output,
name='classification')
predictions = self.classifier(sequence_output)
super(BertTokenClassifier, self).__init__(
inputs=inputs, outputs=predictions, **kwargs)
def get_config(self):
return self._config
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for BERT trainer network."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling import networks
from official.nlp.modeling.models import bert_token_classifier
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
class BertTokenClassifierTest(keras_parameterized.TestCase):
def test_bert_trainer(self):
"""Validate that the Keras object can be created."""
# Build a transformer network to use within the BERT trainer.
vocab_size = 100
sequence_length = 512
test_network = networks.TransformerEncoder(
vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
# Create a BERT trainer with the created network.
num_classes = 3
bert_trainer_model = bert_token_classifier.BertTokenClassifier(
test_network,
num_classes=num_classes)
# Create a set of 2-dimensional inputs (the first dimension is implicit).
word_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
# Invoke the trainer model on the inputs. This causes the layer to be built.
sequence_outs = bert_trainer_model([word_ids, mask, type_ids])
# Validate that the outputs are of the expected shape.
expected_classification_shape = [None, sequence_length, num_classes]
self.assertAllEqual(expected_classification_shape,
sequence_outs.shape.as_list())
def test_bert_trainer_tensor_call(self):
"""Validate that the Keras object can be invoked."""
# Build a transformer network to use within the BERT trainer. (Here, we use
# a short sequence_length for convenience.)
test_network = networks.TransformerEncoder(
vocab_size=100, num_layers=2, sequence_length=2)
# Create a BERT trainer with the created network.
bert_trainer_model = bert_token_classifier.BertTokenClassifier(
test_network, num_classes=2)
# Create a set of 2-dimensional data tensors to feed into the model.
word_ids = tf.constant([[1, 1], [2, 2]], dtype=tf.int32)
mask = tf.constant([[1, 1], [1, 0]], dtype=tf.int32)
type_ids = tf.constant([[1, 1], [2, 2]], dtype=tf.int32)
# Invoke the trainer model on the tensors. In Eager mode, this does the
# actual calculation. (We can't validate the outputs, since the network is
# too complex: this simply ensures we're not hitting runtime errors.)
_ = bert_trainer_model([word_ids, mask, type_ids])
def test_serialize_deserialize(self):
"""Validate that the BERT trainer can be serialized and deserialized."""
# Build a transformer network to use within the BERT trainer. (Here, we use
# a short sequence_length for convenience.)
test_network = networks.TransformerEncoder(
vocab_size=100, num_layers=2, sequence_length=5)
# Create a BERT trainer with the created network. (Note that all the args
# are different, so we can catch any serialization mismatches.)
bert_trainer_model = bert_token_classifier.BertTokenClassifier(
test_network, num_classes=4, initializer='zeros', output='predictions')
# Create another BERT trainer via serialization and deserialization.
config = bert_trainer_model.get_config()
new_bert_trainer_model = (
bert_token_classifier.BertTokenClassifier.from_config(config))
# Validate that the config can be forced to JSON.
_ = new_bert_trainer_model.to_json()
# If the serialization was successful, the new config should match the old.
self.assertAllEqual(bert_trainer_model.get_config(),
new_bert_trainer_model.get_config())
if __name__ == '__main__':
tf.test.main()
......@@ -20,5 +20,8 @@ into two smaller matrices and shares parameters across layers.
* [`Classification`](classification.py) contains a single hidden layer, and is intended for use as a classification head.
* [`TokenClassification`](token_classification.py) contains a single hidden
layer, and is intended for use as a token classification head.
* [`SpanLabeling`](span_labeling.py) implements a single-span labeler (that is, a prediction head that can predict one start and end index per batch item) based on a single dense hidden layer. It can be used in the SQuAD task.
......@@ -18,4 +18,5 @@ from official.nlp.modeling.networks.classification import Classification
from official.nlp.modeling.networks.encoder_scaffold import EncoderScaffold
from official.nlp.modeling.networks.masked_lm import MaskedLM
from official.nlp.modeling.networks.span_labeling import SpanLabeling
from official.nlp.modeling.networks.token_classification import TokenClassification
from official.nlp.modeling.networks.transformer_encoder import TransformerEncoder
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classification network."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import tensorflow as tf
@tf.keras.utils.register_keras_serializable(package='Text')
class TokenClassification(tf.keras.Model):
"""TokenClassification network head for BERT modeling.
This network implements a simple token classifier head based on a dense layer.
Arguments:
input_width: The innermost dimension of the input tensor to this network.
num_classes: The number of classes that this network should classify to.
activation: The activation, if any, for the dense layer in this network.
initializer: The intializer for the dense layer in this network. Defaults to
a Glorot uniform initializer.
output: The output style for this network. Can be either 'logits' or
'predictions'.
"""
def __init__(self,
input_width,
num_classes,
initializer='glorot_uniform',
output='logits',
**kwargs):
self._self_setattr_tracking = False
self._config_dict = {
'input_width': input_width,
'num_classes': num_classes,
'initializer': initializer,
'output': output,
}
sequence_data = tf.keras.layers.Input(
shape=(None, input_width), name='sequence_data', dtype=tf.float32)
self.logits = tf.keras.layers.Dense(
num_classes,
activation=None,
kernel_initializer=initializer,
name='predictions/transform/logits')(
sequence_data)
predictions = tf.keras.layers.Activation(tf.nn.log_softmax)(self.logits)
if output == 'logits':
output_tensors = self.logits
elif output == 'predictions':
output_tensors = predictions
else:
raise ValueError(
('Unknown `output` value "%s". `output` can be either "logits" or '
'"predictions"') % output)
super(TokenClassification, self).__init__(
inputs=[sequence_data], outputs=output_tensors, **kwargs)
def get_config(self):
return self._config_dict
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for token classification network."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling.networks import token_classification
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
class TokenClassificationTest(keras_parameterized.TestCase):
def test_network_creation(self):
"""Validate that the Keras object can be created."""
sequence_length = 5
input_width = 512
num_classes = 10
test_object = token_classification.TokenClassification(
input_width=input_width, num_classes=num_classes)
# Create a 3-dimensional input (the first dimension is implicit).
sequence_data = tf.keras.Input(shape=(sequence_length, input_width),
dtype=tf.float32)
output = test_object(sequence_data)
# Validate that the outputs are of the expected shape.
expected_output_shape = [None, sequence_length, num_classes]
self.assertEqual(expected_output_shape, output.shape.as_list())
def test_network_invocation(self):
"""Validate that the Keras object can be invoked."""
sequence_length = 5
input_width = 512
num_classes = 10
test_object = token_classification.TokenClassification(
input_width=input_width, num_classes=num_classes, output='predictions')
# Create a 3-dimensional input (the first dimension is implicit).
sequence_data = tf.keras.Input(shape=(sequence_length, input_width),
dtype=tf.float32)
output = test_object(sequence_data)
# Invoke the network as part of a Model.
model = tf.keras.Model(sequence_data, output)
input_data = 10 * np.random.random_sample((3, sequence_length, input_width))
_ = model.predict(input_data)
def test_network_invocation_with_internal_logits(self):
"""Validate that the logit outputs are correct."""
sequence_length = 5
input_width = 512
num_classes = 10
test_object = token_classification.TokenClassification(
input_width=input_width, num_classes=num_classes, output='predictions')
# Create a 3-dimensional input (the first dimension is implicit).
sequence_data = tf.keras.Input(shape=(sequence_length, input_width),
dtype=tf.float32)
output = test_object(sequence_data)
model = tf.keras.Model(sequence_data, output)
logits_model = tf.keras.Model(test_object.inputs, test_object.logits)
batch_size = 3
input_data = 10 * np.random.random_sample(
(batch_size, sequence_length, input_width))
outputs = model.predict(input_data)
logits = logits_model.predict(input_data)
# Ensure that the tensor shapes are correct.
expected_output_shape = (batch_size, sequence_length, num_classes)
self.assertEqual(expected_output_shape, outputs.shape)
self.assertEqual(expected_output_shape, logits.shape)
# Ensure that the logits, when softmaxed, create the outputs.
input_tensor = tf.keras.Input(expected_output_shape[1:])
output_tensor = tf.keras.layers.Activation(tf.nn.log_softmax)(input_tensor)
softmax_model = tf.keras.Model(input_tensor, output_tensor)
calculated_softmax = softmax_model.predict(logits)
self.assertAllClose(outputs, calculated_softmax)
def test_network_invocation_with_internal_and_external_logits(self):
"""Validate that the logit outputs are correct."""
sequence_length = 5
input_width = 512
num_classes = 10
test_object = token_classification.TokenClassification(
input_width=input_width, num_classes=num_classes, output='logits')
# Create a 3-dimensional input (the first dimension is implicit).
sequence_data = tf.keras.Input(shape=(sequence_length, input_width),
dtype=tf.float32)
output = test_object(sequence_data)
model = tf.keras.Model(sequence_data, output)
logits_model = tf.keras.Model(test_object.inputs, test_object.logits)
batch_size = 3
input_data = 10 * np.random.random_sample(
(batch_size, sequence_length, input_width))
outputs = model.predict(input_data)
logits = logits_model.predict(input_data)
# Ensure that the tensor shapes are correct.
expected_output_shape = (batch_size, sequence_length, num_classes)
self.assertEqual(expected_output_shape, outputs.shape)
self.assertEqual(expected_output_shape, logits.shape)
self.assertAllClose(outputs, logits)
def test_network_invocation_with_logit_output(self):
"""Validate that the logit outputs are correct."""
sequence_length = 5
input_width = 512
num_classes = 10
test_object = token_classification.TokenClassification(
input_width=input_width, num_classes=num_classes, output='predictions')
logit_object = token_classification.TokenClassification(
input_width=input_width, num_classes=num_classes, output='logits')
logit_object.set_weights(test_object.get_weights())
# Create a 3-dimensional input (the first dimension is implicit).
sequence_data = tf.keras.Input(shape=(sequence_length, input_width),
dtype=tf.float32)
output = test_object(sequence_data)
logit_output = logit_object(sequence_data)
model = tf.keras.Model(sequence_data, output)
logits_model = tf.keras.Model(sequence_data, logit_output)
batch_size = 3
input_data = 10 * np.random.random_sample(
(batch_size, sequence_length, input_width))
outputs = model.predict(input_data)
logits = logits_model.predict(input_data)
# Ensure that the tensor shapes are correct.
expected_output_shape = (batch_size, sequence_length, num_classes)
self.assertEqual(expected_output_shape, outputs.shape)
self.assertEqual(expected_output_shape, logits.shape)
# Ensure that the logits, when softmaxed, create the outputs.
input_tensor = tf.keras.Input(expected_output_shape[1:])
output_tensor = tf.keras.layers.Activation(tf.nn.log_softmax)(input_tensor)
softmax_model = tf.keras.Model(input_tensor, output_tensor)
calculated_softmax = softmax_model.predict(logits)
self.assertAllClose(outputs, calculated_softmax)
def test_serialize_deserialize(self):
# Create a network object that sets all of its config options.
network = token_classification.TokenClassification(
input_width=128,
num_classes=10,
initializer='zeros',
output='predictions')
# Create another network object from the first object's config.
new_network = token_classification.TokenClassification.from_config(
network.get_config())
# Validate that the config can be forced to JSON.
_ = new_network.to_json()
# If the serialization was successful, the new config should match the old.
self.assertAllEqual(network.get_config(), new_network.get_config())
def test_unknown_output_type_fails(self):
with self.assertRaisesRegex(ValueError, 'Unknown `output` value "bad".*'):
_ = token_classification.TokenClassification(
input_width=128, num_classes=10, output='bad')
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment