Commit 250701c6 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Remove TokenClassification network which is a single dense layer.

PiperOrigin-RevId: 323711016
parent 5a1b5af3
...@@ -478,7 +478,7 @@ ...@@ -478,7 +478,7 @@
"source": [ "source": [
"### Build a BertClassifier model wrapping TransformerEncoder\n", "### Build a BertClassifier model wrapping TransformerEncoder\n",
"\n", "\n",
"[BertClassifier](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/models/bert_classifier.py) implements a simple token classification model containing a single classification head using the `TokenClassification` network." "[BertClassifier](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/models/bert_classifier.py) implements a [CLS] token classification model containing a single classification head."
] ]
}, },
{ {
......
...@@ -10,8 +10,8 @@ model containing a single classification head using the Classification network. ...@@ -10,8 +10,8 @@ model containing a single classification head using the Classification network.
It can be used as a regression model as well. It can be used as a regression model as well.
* [`BertTokenClassifier`](bert_token_classifier.py) implements a simple token * [`BertTokenClassifier`](bert_token_classifier.py) implements a simple token
classification model containing a single classification head using the classification model containing a single classification head over the sequence
TokenClassification network. output embeddings.
* [`BertSpanLabeler`](bert_span_labeler.py) implements a simple single-span * [`BertSpanLabeler`](bert_span_labeler.py) implements a simple single-span
start-end predictor (that is, a model that predicts two values: a start token start-end predictor (that is, a model that predicts two values: a start token
......
...@@ -12,12 +12,8 @@ ...@@ -12,12 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Trainer network for BERT-style models.""" """BERT cls-token classifier."""
# pylint: disable=g-classes-have-attributes # pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import tensorflow as tf import tensorflow as tf
......
...@@ -12,12 +12,8 @@ ...@@ -12,12 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Trainer network for BERT-style models.""" """BERT Pre-training model."""
# pylint: disable=g-classes-have-attributes # pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import copy import copy
from typing import List, Optional from typing import List, Optional
......
...@@ -12,12 +12,8 @@ ...@@ -12,12 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Trainer network for BERT-style models.""" """BERT Question Answering model."""
# pylint: disable=g-classes-have-attributes # pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import tensorflow as tf import tensorflow as tf
......
...@@ -12,17 +12,11 @@ ...@@ -12,17 +12,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Trainer network for BERT-style models.""" """BERT token classifier."""
# pylint: disable=g-classes-have-attributes # pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import tensorflow as tf import tensorflow as tf
from official.nlp.modeling import networks
@tf.keras.utils.register_keras_serializable(package='Text') @tf.keras.utils.register_keras_serializable(package='Text')
class BertTokenClassifier(tf.keras.Model): class BertTokenClassifier(tf.keras.Model):
...@@ -77,16 +71,23 @@ class BertTokenClassifier(tf.keras.Model): ...@@ -77,16 +71,23 @@ class BertTokenClassifier(tf.keras.Model):
sequence_output = tf.keras.layers.Dropout( sequence_output = tf.keras.layers.Dropout(
rate=dropout_rate)(sequence_output) rate=dropout_rate)(sequence_output)
self.classifier = networks.TokenClassification( self.classifier = tf.keras.layers.Dense(
input_width=sequence_output.shape[-1], num_classes,
num_classes=num_classes, activation=None,
initializer=initializer, kernel_initializer=initializer,
output=output, name='predictions/transform/logits')
name='classification') self.logits = self.classifier(sequence_output)
predictions = self.classifier(sequence_output) if output == 'logits':
output_tensors = self.logits
elif output == 'predictions':
output_tensors = tf.keras.layers.Activation(tf.nn.log_softmax)(
self.logits)
else:
raise ValueError(
('Unknown `output` value "%s". `output` can be either "logits" or '
'"predictions"') % output)
super(BertTokenClassifier, self).__init__( super(BertTokenClassifier, self).__init__(
inputs=inputs, outputs=predictions, **kwargs) inputs=inputs, outputs=output_tensors, **kwargs)
@property @property
def checkpoint_items(self): def checkpoint_items(self):
......
...@@ -20,8 +20,5 @@ into two smaller matrices and shares parameters across layers. ...@@ -20,8 +20,5 @@ into two smaller matrices and shares parameters across layers.
intended for use as a classification or regression (if number of classes is set intended for use as a classification or regression (if number of classes is set
to 1) head. to 1) head.
* [`TokenClassification`](token_classification.py) contains a single hidden
layer, and is intended for use as a token classification head.
* [`SpanLabeling`](span_labeling.py) implements a single-span labeler (that is, a prediction head that can predict one start and end index per batch item) based on a single dense hidden layer. It can be used in the SQuAD task. * [`SpanLabeling`](span_labeling.py) implements a single-span labeler (that is, a prediction head that can predict one start and end index per batch item) based on a single dense hidden layer. It can be used in the SQuAD task.
...@@ -17,5 +17,4 @@ from official.nlp.modeling.networks.albert_transformer_encoder import AlbertTran ...@@ -17,5 +17,4 @@ from official.nlp.modeling.networks.albert_transformer_encoder import AlbertTran
from official.nlp.modeling.networks.classification import Classification from official.nlp.modeling.networks.classification import Classification
from official.nlp.modeling.networks.encoder_scaffold import EncoderScaffold from official.nlp.modeling.networks.encoder_scaffold import EncoderScaffold
from official.nlp.modeling.networks.span_labeling import SpanLabeling from official.nlp.modeling.networks.span_labeling import SpanLabeling
from official.nlp.modeling.networks.token_classification import TokenClassification
from official.nlp.modeling.networks.transformer_encoder import TransformerEncoder from official.nlp.modeling.networks.transformer_encoder import TransformerEncoder
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classification network."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import tensorflow as tf
@tf.keras.utils.register_keras_serializable(package='Text')
class TokenClassification(tf.keras.Model):
  """TokenClassification network head for BERT modeling.

  This network implements a simple token classifier head: a single dense
  layer applied to every position of a sequence of embeddings, optionally
  followed by a log-softmax.

  *Note* that the network is constructed by the
  [Keras Functional API](https://keras.io/guides/functional_api/).

  Arguments:
    input_width: The innermost dimension of the input tensor to this network.
    num_classes: The number of classes that this network should classify to.
    initializer: The initializer for the dense layer in this network. Defaults
      to a Glorot uniform initializer.
    output: The output style for this network. Can be either 'logits' or
      'predictions' (log-softmax of the logits).
    activation: The activation, if any, for the dense layer in this network.
      Defaults to None (linear), which preserves the historical behavior.
      Note that a non-linear activation makes the 'logits' output no longer
      true logits.

  Raises:
    ValueError: If `output` is neither 'logits' nor 'predictions'.
  """

  def __init__(self,
               input_width,
               num_classes,
               initializer='glorot_uniform',
               output='logits',
               activation=None,
               **kwargs):
    # Disable Keras attribute tracking so that assigning `self.logits` below,
    # before super().__init__ runs, does not interfere with functional-model
    # construction.
    self._self_setattr_tracking = False
    self._config_dict = {
        'input_width': input_width,
        'num_classes': num_classes,
        'initializer': initializer,
        'output': output,
        'activation': activation,
    }

    sequence_data = tf.keras.layers.Input(
        shape=(None, input_width), name='sequence_data', dtype=tf.float32)

    # Per-token projection to class scores. The `activation` argument was
    # documented but previously not accepted by __init__; it is now wired
    # through with a backward-compatible default of None.
    self.logits = tf.keras.layers.Dense(
        num_classes,
        activation=activation,
        kernel_initializer=initializer,
        name='predictions/transform/logits')(
            sequence_data)
    predictions = tf.keras.layers.Activation(tf.nn.log_softmax)(self.logits)

    if output == 'logits':
      output_tensors = self.logits
    elif output == 'predictions':
      output_tensors = predictions
    else:
      raise ValueError(
          ('Unknown `output` value "%s". `output` can be either "logits" or '
           '"predictions"') % output)

    super(TokenClassification, self).__init__(
        inputs=[sequence_data], outputs=output_tensors, **kwargs)

  def get_config(self):
    """Returns the config dict needed to reconstruct this network."""
    return self._config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Creates a TokenClassification instance from its config dict."""
    return cls(**config)
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for token classification network."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from official.nlp.modeling.networks import token_classification
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
class TokenClassificationTest(keras_parameterized.TestCase):
  """Tests for the TokenClassification network head.

  NOTE: several tests below rely on the exact layer-construction order of
  TokenClassification (e.g. transferring weights between two independently
  built networks via set_weights), so they exercise structure as well as
  outputs.
  """

  def test_network_creation(self):
    """Validate that the Keras object can be created."""
    sequence_length = 5
    input_width = 512
    num_classes = 10
    test_object = token_classification.TokenClassification(
        input_width=input_width, num_classes=num_classes)
    # Create a 3-dimensional input (the first dimension is implicit).
    sequence_data = tf.keras.Input(shape=(sequence_length, input_width),
                                   dtype=tf.float32)
    output = test_object(sequence_data)

    # Validate that the outputs are of the expected shape: one score per
    # class at every token position.
    expected_output_shape = [None, sequence_length, num_classes]
    self.assertEqual(expected_output_shape, output.shape.as_list())

  def test_network_invocation(self):
    """Validate that the Keras object can be invoked."""
    sequence_length = 5
    input_width = 512
    num_classes = 10
    test_object = token_classification.TokenClassification(
        input_width=input_width, num_classes=num_classes, output='predictions')
    # Create a 3-dimensional input (the first dimension is implicit).
    sequence_data = tf.keras.Input(shape=(sequence_length, input_width),
                                   dtype=tf.float32)
    output = test_object(sequence_data)

    # Invoke the network as part of a Model; this runs a real forward pass on
    # random data rather than just building the graph.
    model = tf.keras.Model(sequence_data, output)
    input_data = 10 * np.random.random_sample((3, sequence_length, input_width))
    _ = model.predict(input_data)

  def test_network_invocation_with_internal_logits(self):
    """Validate that the internal `logits` tensor matches the predictions."""
    sequence_length = 5
    input_width = 512
    num_classes = 10
    test_object = token_classification.TokenClassification(
        input_width=input_width, num_classes=num_classes, output='predictions')
    # Create a 3-dimensional input (the first dimension is implicit).
    sequence_data = tf.keras.Input(shape=(sequence_length, input_width),
                                   dtype=tf.float32)
    output = test_object(sequence_data)
    model = tf.keras.Model(sequence_data, output)
    # A second model that exposes the pre-softmax logits kept on the network
    # object itself.
    logits_model = tf.keras.Model(test_object.inputs, test_object.logits)

    batch_size = 3
    input_data = 10 * np.random.random_sample(
        (batch_size, sequence_length, input_width))
    outputs = model.predict(input_data)
    logits = logits_model.predict(input_data)

    # Ensure that the tensor shapes are correct.
    expected_output_shape = (batch_size, sequence_length, num_classes)
    self.assertEqual(expected_output_shape, outputs.shape)
    self.assertEqual(expected_output_shape, logits.shape)

    # Ensure that the logits, when log-softmaxed, create the outputs.
    input_tensor = tf.keras.Input(expected_output_shape[1:])
    output_tensor = tf.keras.layers.Activation(tf.nn.log_softmax)(input_tensor)
    softmax_model = tf.keras.Model(input_tensor, output_tensor)

    calculated_softmax = softmax_model.predict(logits)
    self.assertAllClose(outputs, calculated_softmax)

  def test_network_invocation_with_internal_and_external_logits(self):
    """Validate that output='logits' equals the internal `logits` tensor."""
    sequence_length = 5
    input_width = 512
    num_classes = 10
    test_object = token_classification.TokenClassification(
        input_width=input_width, num_classes=num_classes, output='logits')
    # Create a 3-dimensional input (the first dimension is implicit).
    sequence_data = tf.keras.Input(shape=(sequence_length, input_width),
                                   dtype=tf.float32)
    output = test_object(sequence_data)
    model = tf.keras.Model(sequence_data, output)
    logits_model = tf.keras.Model(test_object.inputs, test_object.logits)

    batch_size = 3
    input_data = 10 * np.random.random_sample(
        (batch_size, sequence_length, input_width))
    outputs = model.predict(input_data)
    logits = logits_model.predict(input_data)

    # Ensure that the tensor shapes are correct.
    expected_output_shape = (batch_size, sequence_length, num_classes)
    self.assertEqual(expected_output_shape, outputs.shape)
    self.assertEqual(expected_output_shape, logits.shape)

    # With output='logits' the model's output IS the logits tensor.
    self.assertAllClose(outputs, logits)

  def test_network_invocation_with_logit_output(self):
    """Validate the logits/predictions relationship across two networks."""
    sequence_length = 5
    input_width = 512
    num_classes = 10
    test_object = token_classification.TokenClassification(
        input_width=input_width, num_classes=num_classes, output='predictions')
    logit_object = token_classification.TokenClassification(
        input_width=input_width, num_classes=num_classes, output='logits')
    # Weight transfer relies on both networks having identical layer
    # structure and construction order.
    logit_object.set_weights(test_object.get_weights())

    # Create a 3-dimensional input (the first dimension is implicit).
    sequence_data = tf.keras.Input(shape=(sequence_length, input_width),
                                   dtype=tf.float32)
    output = test_object(sequence_data)
    logit_output = logit_object(sequence_data)

    model = tf.keras.Model(sequence_data, output)
    logits_model = tf.keras.Model(sequence_data, logit_output)

    batch_size = 3
    input_data = 10 * np.random.random_sample(
        (batch_size, sequence_length, input_width))
    outputs = model.predict(input_data)
    logits = logits_model.predict(input_data)

    # Ensure that the tensor shapes are correct.
    expected_output_shape = (batch_size, sequence_length, num_classes)
    self.assertEqual(expected_output_shape, outputs.shape)
    self.assertEqual(expected_output_shape, logits.shape)

    # Ensure that the logits, when log-softmaxed, create the outputs.
    input_tensor = tf.keras.Input(expected_output_shape[1:])
    output_tensor = tf.keras.layers.Activation(tf.nn.log_softmax)(input_tensor)
    softmax_model = tf.keras.Model(input_tensor, output_tensor)

    calculated_softmax = softmax_model.predict(logits)
    self.assertAllClose(outputs, calculated_softmax)

  def test_serialize_deserialize(self):
    """Validate config round-trip through get_config/from_config."""
    # Create a network object that sets all of its config options.
    network = token_classification.TokenClassification(
        input_width=128,
        num_classes=10,
        initializer='zeros',
        output='predictions')

    # Create another network object from the first object's config.
    new_network = token_classification.TokenClassification.from_config(
        network.get_config())

    # Validate that the config can be forced to JSON.
    _ = new_network.to_json()

    # If the serialization was successful, the new config should match the old.
    self.assertAllEqual(network.get_config(), new_network.get_config())

  def test_unknown_output_type_fails(self):
    """Validate that an invalid `output` value raises ValueError."""
    with self.assertRaisesRegex(ValueError, 'Unknown `output` value "bad".*'):
      _ = token_classification.TokenClassification(
          input_width=128, num_classes=10, output='bad')
if __name__ == '__main__':
  # Discover and run all test cases via the TensorFlow test runner.
  tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment