Commit d90bed2e authored by Maxim Neumann, committed by A. Unique TensorFlower
Browse files

Update documentation and tests of model and network for regression tasks.

PiperOrigin-RevId: 314486753
parent 46245913
...@@ -7,6 +7,7 @@ models are intended as both convenience functions and canonical examples. ...@@ -7,6 +7,7 @@ models are intended as both convenience functions and canonical examples.
* [`BertClassifier`](bert_classifier.py) implements a simple classification * [`BertClassifier`](bert_classifier.py) implements a simple classification
model containing a single classification head using the Classification network. model containing a single classification head using the Classification network.
It can be used as a regression model as well.
* [`BertTokenClassifier`](bert_token_classifier.py) implements a simple token * [`BertTokenClassifier`](bert_token_classifier.py) implements a simple token
classification model containing a single classification head using the classification model containing a single classification head using the
......
...@@ -34,7 +34,7 @@ class BertClassifier(tf.keras.Model): ...@@ -34,7 +34,7 @@ class BertClassifier(tf.keras.Model):
The BertClassifier allows a user to pass in a transformer stack, and The BertClassifier allows a user to pass in a transformer stack, and
instantiates a classification network based on the passed `num_classes` instantiates a classification network based on the passed `num_classes`
argument. argument. If `num_classes` is set to 1, a regression network is instantiated.
Arguments: Arguments:
network: A transformer network. This network should output a sequence output network: A transformer network. This network should output a sequence output
......
...@@ -18,6 +18,7 @@ from __future__ import absolute_import ...@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
...@@ -30,7 +31,8 @@ from official.nlp.modeling.models import bert_classifier ...@@ -30,7 +31,8 @@ from official.nlp.modeling.models import bert_classifier
@keras_parameterized.run_all_keras_modes @keras_parameterized.run_all_keras_modes
class BertClassifierTest(keras_parameterized.TestCase): class BertClassifierTest(keras_parameterized.TestCase):
def test_bert_trainer(self): @parameterized.parameters(1, 3)
def test_bert_trainer(self, num_classes):
"""Validate that the Keras object can be created.""" """Validate that the Keras object can be created."""
# Build a transformer network to use within the BERT trainer. # Build a transformer network to use within the BERT trainer.
vocab_size = 100 vocab_size = 100
...@@ -39,7 +41,6 @@ class BertClassifierTest(keras_parameterized.TestCase): ...@@ -39,7 +41,6 @@ class BertClassifierTest(keras_parameterized.TestCase):
vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length) vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
# Create a BERT trainer with the created network. # Create a BERT trainer with the created network.
num_classes = 3
bert_trainer_model = bert_classifier.BertClassifier( bert_trainer_model = bert_classifier.BertClassifier(
test_network, test_network,
num_classes=num_classes) num_classes=num_classes)
...@@ -56,7 +57,8 @@ class BertClassifierTest(keras_parameterized.TestCase): ...@@ -56,7 +57,8 @@ class BertClassifierTest(keras_parameterized.TestCase):
expected_classification_shape = [None, num_classes] expected_classification_shape = [None, num_classes]
self.assertAllEqual(expected_classification_shape, cls_outs.shape.as_list()) self.assertAllEqual(expected_classification_shape, cls_outs.shape.as_list())
def test_bert_trainer_tensor_call(self): @parameterized.parameters(1, 2)
def test_bert_trainer_tensor_call(self, num_classes):
"""Validate that the Keras object can be invoked.""" """Validate that the Keras object can be invoked."""
# Build a transformer network to use within the BERT trainer. (Here, we use # Build a transformer network to use within the BERT trainer. (Here, we use
# a short sequence_length for convenience.) # a short sequence_length for convenience.)
...@@ -65,7 +67,7 @@ class BertClassifierTest(keras_parameterized.TestCase): ...@@ -65,7 +67,7 @@ class BertClassifierTest(keras_parameterized.TestCase):
# Create a BERT trainer with the created network. # Create a BERT trainer with the created network.
bert_trainer_model = bert_classifier.BertClassifier( bert_trainer_model = bert_classifier.BertClassifier(
test_network, num_classes=2) test_network, num_classes=num_classes)
# Create a set of 2-dimensional data tensors to feed into the model. # Create a set of 2-dimensional data tensors to feed into the model.
word_ids = tf.constant([[1, 1], [2, 2]], dtype=tf.int32) word_ids = tf.constant([[1, 1], [2, 2]], dtype=tf.int32)
......
...@@ -18,7 +18,9 @@ into two smaller matrices and shares parameters across layers. ...@@ -18,7 +18,9 @@ into two smaller matrices and shares parameters across layers.
* [`MaskedLM`](masked_lm.py) implements a masked language model for BERT pretraining. It assumes that the network being passed has a `get_embedding_table()` method. * [`MaskedLM`](masked_lm.py) implements a masked language model for BERT pretraining. It assumes that the network being passed has a `get_embedding_table()` method.
* [`Classification`](classification.py) contains a single hidden layer, and is intended for use as a classification head. * [`Classification`](classification.py) contains a single hidden layer, and is
intended for use as a classification or regression (if number of classes is set
to 1) head.
* [`TokenClassification`](token_classification.py) contains a single hidden * [`TokenClassification`](token_classification.py) contains a single hidden
layer, and is intended for use as a token classification head. layer, and is intended for use as a token classification head.
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Classification network.""" """Classification and regression network."""
# pylint: disable=g-classes-have-attributes # pylint: disable=g-classes-have-attributes
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -26,11 +26,13 @@ import tensorflow as tf ...@@ -26,11 +26,13 @@ import tensorflow as tf
class Classification(tf.keras.Model): class Classification(tf.keras.Model):
"""Classification network head for BERT modeling. """Classification network head for BERT modeling.
This network implements a simple classifier head based on a dense layer. This network implements a simple classifier head based on a dense layer. If
num_classes is one, it can be considered as a regression problem.
Arguments: Arguments:
input_width: The innermost dimension of the input tensor to this network. input_width: The innermost dimension of the input tensor to this network.
num_classes: The number of classes that this network should classify to. num_classes: The number of classes that this network should classify to. If
equal to 1, a regression problem is assumed.
activation: The activation, if any, for the dense layer in this network. activation: The activation, if any, for the dense layer in this network.
initializer: The initializer for the dense layer in this network. Defaults to initializer: The initializer for the dense layer in this network. Defaults to
a Glorot uniform initializer. a Glorot uniform initializer.
......
...@@ -18,6 +18,7 @@ from __future__ import absolute_import ...@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from absl.testing import parameterized
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
...@@ -30,10 +31,10 @@ from official.nlp.modeling.networks import classification ...@@ -30,10 +31,10 @@ from official.nlp.modeling.networks import classification
@keras_parameterized.run_all_keras_modes @keras_parameterized.run_all_keras_modes
class ClassificationTest(keras_parameterized.TestCase): class ClassificationTest(keras_parameterized.TestCase):
def test_network_creation(self): @parameterized.parameters(1, 10)
def test_network_creation(self, num_classes):
"""Validate that the Keras object can be created.""" """Validate that the Keras object can be created."""
input_width = 512 input_width = 512
num_classes = 10
test_object = classification.Classification( test_object = classification.Classification(
input_width=input_width, num_classes=num_classes) input_width=input_width, num_classes=num_classes)
# Create a 2-dimensional input (the first dimension is implicit). # Create a 2-dimensional input (the first dimension is implicit).
...@@ -44,10 +45,10 @@ class ClassificationTest(keras_parameterized.TestCase): ...@@ -44,10 +45,10 @@ class ClassificationTest(keras_parameterized.TestCase):
expected_output_shape = [None, num_classes] expected_output_shape = [None, num_classes]
self.assertEqual(expected_output_shape, output.shape.as_list()) self.assertEqual(expected_output_shape, output.shape.as_list())
def test_network_invocation(self): @parameterized.parameters(1, 10)
def test_network_invocation(self, num_classes):
"""Validate that the Keras object can be invoked.""" """Validate that the Keras object can be invoked."""
input_width = 512 input_width = 512
num_classes = 10
test_object = classification.Classification( test_object = classification.Classification(
input_width=input_width, num_classes=num_classes, output='predictions') input_width=input_width, num_classes=num_classes, output='predictions')
# Create a 2-dimensional input (the first dimension is implicit). # Create a 2-dimensional input (the first dimension is implicit).
...@@ -90,10 +91,11 @@ class ClassificationTest(keras_parameterized.TestCase): ...@@ -90,10 +91,11 @@ class ClassificationTest(keras_parameterized.TestCase):
calculated_softmax = softmax_model.predict(logits) calculated_softmax = softmax_model.predict(logits)
self.assertAllClose(outputs, calculated_softmax) self.assertAllClose(outputs, calculated_softmax)
def test_network_invocation_with_internal_and_external_logits(self): @parameterized.parameters(1, 10)
def test_network_invocation_with_internal_and_external_logits(self,
num_classes):
"""Validate that the logit outputs are correct.""" """Validate that the logit outputs are correct."""
input_width = 512 input_width = 512
num_classes = 10
test_object = classification.Classification( test_object = classification.Classification(
input_width=input_width, num_classes=num_classes, output='logits') input_width=input_width, num_classes=num_classes, output='logits')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment