"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "3397eff0cd0fa921fb190631dd4423d2648712c6"
Commit d90bed2e authored by Maxim Neumann, committed by A. Unique TensorFlower

Update documentation and tests of model and network for regression tasks.

PiperOrigin-RevId: 314486753
parent 46245913
@@ -7,6 +7,7 @@ models are intended as both convenience functions and canonical examples.
 * [`BertClassifier`](bert_classifier.py) implements a simple classification
   model containing a single classification head using the Classification network.
+  It can be used as a regression model as well.
 * [`BertTokenClassifier`](bert_token_classifier.py) implements a simple token
   classification model containing a single classification head using the
......
@@ -34,7 +34,7 @@ class BertClassifier(tf.keras.Model):
   The BertClassifier allows a user to pass in a transformer stack, and
   instantiates a classification network based on the passed `num_classes`
-  argument.
+  argument. If `num_classes` is set to 1, a regression network is instantiated.
   Arguments:
     network: A transformer network. This network should output a sequence output
......
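The change above means `BertClassifier` can also serve as a regression model by passing `num_classes=1`. A minimal sketch of that usage: the constructor arguments mirror the test code further down, but the `networks.TransformerEncoder` import and class name are assumptions and may differ between releases.

```python
# Sketch only: encoder construction follows the test setup in this commit;
# the TransformerEncoder name is an assumption, not taken from the diff.
from official.nlp.modeling import networks
from official.nlp.modeling.models import bert_classifier

encoder = networks.TransformerEncoder(
    vocab_size=100, num_layers=2, sequence_length=128)

# num_classes=1 instantiates a regression network instead of a classifier.
regressor = bert_classifier.BertClassifier(encoder, num_classes=1)
```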
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
+from absl.testing import parameterized
 import tensorflow as tf
 from tensorflow.python.keras import keras_parameterized  # pylint: disable=g-direct-tensorflow-import
@@ -30,7 +31,8 @@ from official.nlp.modeling.models import bert_classifier
 @keras_parameterized.run_all_keras_modes
 class BertClassifierTest(keras_parameterized.TestCase):
-  def test_bert_trainer(self):
+  @parameterized.parameters(1, 3)
+  def test_bert_trainer(self, num_classes):
     """Validate that the Keras object can be created."""
     # Build a transformer network to use within the BERT trainer.
     vocab_size = 100
@@ -39,7 +41,6 @@ class BertClassifierTest(keras_parameterized.TestCase):
         vocab_size=vocab_size, num_layers=2, sequence_length=sequence_length)
     # Create a BERT trainer with the created network.
-    num_classes = 3
     bert_trainer_model = bert_classifier.BertClassifier(
         test_network,
         num_classes=num_classes)
@@ -56,7 +57,8 @@ class BertClassifierTest(keras_parameterized.TestCase):
     expected_classification_shape = [None, num_classes]
     self.assertAllEqual(expected_classification_shape, cls_outs.shape.as_list())
-  def test_bert_trainer_tensor_call(self):
+  @parameterized.parameters(1, 2)
+  def test_bert_trainer_tensor_call(self, num_classes):
     """Validate that the Keras object can be invoked."""
     # Build a transformer network to use within the BERT trainer. (Here, we use
     # a short sequence_length for convenience.)
@@ -65,7 +67,7 @@ class BertClassifierTest(keras_parameterized.TestCase):
     # Create a BERT trainer with the created network.
     bert_trainer_model = bert_classifier.BertClassifier(
-        test_network, num_classes=2)
+        test_network, num_classes=num_classes)
     # Create a set of 2-dimensional data tensors to feed into the model.
     word_ids = tf.constant([[1, 1], [2, 2]], dtype=tf.int32)
......
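The test updates above use `absl.testing.parameterized` so each test body runs once per `num_classes` value. A self-contained sketch of that pattern; the class, test name, and layer here are illustrative and not taken from the commit.

```python
from absl.testing import parameterized
import tensorflow as tf


class HeadShapeTest(parameterized.TestCase):

  # Runs the test body twice: once with num_classes=1 (regression-style head)
  # and once with num_classes=3 (classification head).
  @parameterized.parameters(1, 3)
  def test_output_width(self, num_classes):
    dense = tf.keras.layers.Dense(num_classes)
    outputs = dense(tf.zeros([2, 8]))
    self.assertEqual([2, num_classes], outputs.shape.as_list())


if __name__ == '__main__':
  tf.test.main()
```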
@@ -18,7 +18,9 @@ into two smaller matrices and shares parameters across layers.
 * [`MaskedLM`](masked_lm.py) implements a masked language model for BERT pretraining. It assumes that the network being passed has a `get_embedding_table()` method.
-* [`Classification`](classification.py) contains a single hidden layer, and is intended for use as a classification head.
+* [`Classification`](classification.py) contains a single hidden layer, and is
+  intended for use as a classification or regression (if number of classes is set
+  to 1) head.
 * [`TokenClassification`](token_classification.py) contains a single hidden
   layer, and is intended for use as a token classification head.
......
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Classification network."""
+"""Classification and regression network."""
 # pylint: disable=g-classes-have-attributes
 from __future__ import absolute_import
 from __future__ import division
@@ -26,11 +26,13 @@ import tensorflow as tf
 class Classification(tf.keras.Model):
   """Classification network head for BERT modeling.
-  This network implements a simple classifier head based on a dense layer.
+  This network implements a simple classifier head based on a dense layer. If
+  num_classes is one, it can be considered as a regression problem.
   Arguments:
     input_width: The innermost dimension of the input tensor to this network.
-    num_classes: The number of classes that this network should classify to.
+    num_classes: The number of classes that this network should classify to. If
+      equal to 1, a regression problem is assumed.
     activation: The activation, if any, for the dense layer in this network.
     initializer: The initializer for the dense layer in this network. Defaults to
       a Glorot uniform initializer.
......
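As the updated docstring says, a `Classification` head with `num_classes=1` emits a single value per example, which is the regression case. A minimal sketch grounded in the test code below; the package path comes from the test imports, and the input shape from the head's `input_width`.

```python
import tensorflow as tf
from official.nlp.modeling.networks import classification

# A head over a pooled [batch, input_width] representation.
head = classification.Classification(input_width=512, num_classes=1)

pooled = tf.zeros([4, 512])   # e.g. a BERT pooled output for a batch of 4
outputs = head(pooled)        # shape [4, 1]: one value per example (regression)
```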
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
+from absl.testing import parameterized
 import numpy as np
 import tensorflow as tf
@@ -30,10 +31,10 @@ from official.nlp.modeling.networks import classification
 @keras_parameterized.run_all_keras_modes
 class ClassificationTest(keras_parameterized.TestCase):
-  def test_network_creation(self):
+  @parameterized.parameters(1, 10)
+  def test_network_creation(self, num_classes):
     """Validate that the Keras object can be created."""
     input_width = 512
-    num_classes = 10
     test_object = classification.Classification(
         input_width=input_width, num_classes=num_classes)
     # Create a 2-dimensional input (the first dimension is implicit).
@@ -44,10 +45,10 @@ class ClassificationTest(keras_parameterized.TestCase):
     expected_output_shape = [None, num_classes]
     self.assertEqual(expected_output_shape, output.shape.as_list())
-  def test_network_invocation(self):
+  @parameterized.parameters(1, 10)
+  def test_network_invocation(self, num_classes):
     """Validate that the Keras object can be invoked."""
     input_width = 512
-    num_classes = 10
     test_object = classification.Classification(
         input_width=input_width, num_classes=num_classes, output='predictions')
     # Create a 2-dimensional input (the first dimension is implicit).
@@ -90,10 +91,11 @@ class ClassificationTest(keras_parameterized.TestCase):
     calculated_softmax = softmax_model.predict(logits)
     self.assertAllClose(outputs, calculated_softmax)
-  def test_network_invocation_with_internal_and_external_logits(self):
+  @parameterized.parameters(1, 10)
+  def test_network_invocation_with_internal_and_external_logits(self,
+                                                                 num_classes):
     """Validate that the logit outputs are correct."""
     input_width = 512
-    num_classes = 10
     test_object = classification.Classification(
         input_width=input_width, num_classes=num_classes, output='logits')
......
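A usage note that is not part of the commit: the regression and classification cases differ only in the width of the final dense layer, so a `num_classes=1` output is typically paired with a different loss. A hedged sketch of that pairing:

```python
import tensorflow as tf


def head_loss(num_classes):
  """Illustrative loss choice for the head; not taken from the commit."""
  if num_classes == 1:
    # One continuous output per example: treat the task as regression.
    return tf.keras.losses.MeanSquaredError()
  # Integer labels scored against num_classes logits: classification.
  return tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
```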