Commit 7e47cd7b authored by Hongkun Yu, committed by A. Unique TensorFlower

Adds clear documentation: Functional/Subclass API used for each network/model.

PiperOrigin-RevId: 321591514
parent 982f457a
@@ -37,6 +37,9 @@ class BertClassifier(tf.keras.Model):
instantiates a classification network based on the passed `num_classes`
argument. If `num_classes` is set to 1, a regression network is instantiated.
*Note* that the model is constructed with the
[Keras Functional API](https://keras.io/guides/functional_api/).
Arguments:
network: A transformer network. This network should output a sequence output
and a classification output. Furthermore, it should expose its embedding
......
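For readers less familiar with the two construction styles named in these notes, below is a minimal, hypothetical sketch of Functional-API construction in the spirit of these classifier heads. The embedding-plus-pooling "encoder" and all sizes are placeholders, not the transformer network these models actually receive.

```python
import tensorflow as tf

# Toy stand-in for a transformer encoder's pooled output; the real models take
# a full transformer network as their `network` argument.
word_ids = tf.keras.Input(shape=(128,), dtype=tf.int32, name="input_word_ids")
embeddings = tf.keras.layers.Embedding(30522, 64)(word_ids)
pooled_output = tf.keras.layers.GlobalAveragePooling1D()(embeddings)

# Functional API: the head is wired from symbolic inputs to outputs, and the
# Model is built from that graph rather than by overriding call().
logits = tf.keras.layers.Dense(3, name="predictions")(pooled_output)
classifier = tf.keras.Model(inputs=word_ids, outputs=logits)
```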
@@ -41,6 +41,9 @@ class BertPretrainer(tf.keras.Model):
instantiates the masked language model and classification networks that are
used to create the training objectives.
*Note* that the model is constructed with the
[Keras Functional API](https://keras.io/guides/functional_api/).
Arguments:
network: A transformer network. This network should output a sequence output
and a classification output.
......
@@ -32,9 +32,12 @@ class BertSpanLabeler(tf.keras.Model):
encoder as described in "BERT: Pre-training of Deep Bidirectional Transformers
for Language Understanding" (https://arxiv.org/abs/1810.04805).
-The BertSpanLabeler allows a user to pass in a transformer stack, and
+The BertSpanLabeler allows a user to pass in a transformer encoder, and
instantiates a span labeling network based on a single dense layer.
*Note* that the model is constructed with the
[Keras Functional API](https://keras.io/guides/functional_api/).
Arguments:
network: A transformer network. This network should output a sequence output
and a classification output. Furthermore, it should expose its embedding
......
@@ -36,6 +36,9 @@ class BertTokenClassifier(tf.keras.Model):
instantiates a token classification network based on the passed `num_classes`
argument.
*Note* that the model is constructed with the
[Keras Functional API](https://keras.io/guides/functional_api/).
Arguments:
network: A transformer network. This network should output a sequence output
and a classification output. Furthermore, it should expose its embedding
......
@@ -39,6 +39,9 @@ class ElectraPretrainer(tf.keras.Model):
model (on the generator side) and classification networks (on the discriminator side)
that are used to create the training objectives.
*Note* that the model is constructed with the Keras Subclass API, where layers are
defined inside __init__ and call() implements the computation.
Arguments:
generator_network: A transformer network for the generator. This network should
output a sequence output and an optional classification output.
......
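By contrast, the Subclass-API style mentioned above defines layers in __init__ and the computation in call(). A minimal illustrative sketch follows; it is not the ElectraPretrainer itself, and the layer sizes are made up.

```python
import tensorflow as tf

class TinySubclassModel(tf.keras.Model):
  """Illustrative only: layers live in __init__, computation lives in call()."""

  def __init__(self, hidden_size=64, num_classes=2, **kwargs):
    super().__init__(**kwargs)
    self.hidden = tf.keras.layers.Dense(hidden_size, activation="relu")
    self.classifier = tf.keras.layers.Dense(num_classes)

  def call(self, inputs):
    # The forward pass is written imperatively instead of being captured as a
    # static layer graph at construction time.
    return self.classifier(self.hidden(inputs))
```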
@@ -40,6 +40,8 @@ class AlbertTransformerEncoder(tf.keras.Model):
The default values for this object are taken from the ALBERT-Base
implementation described in the paper.
*Note* that the network is constructed with the Keras Functional API.
Arguments:
vocab_size: The size of the token vocabulary.
embedding_width: The width of the word embeddings. If the embedding width is
......
@@ -29,6 +29,9 @@ class Classification(tf.keras.Model):
This network implements a simple classifier head based on a dense layer. If
`num_classes` is one, it can be treated as a regression problem.
*Note* that the network is constructed with the
[Keras Functional API](https://keras.io/guides/functional_api/).
Arguments:
input_width: The innermost dimension of the input tensor to this network.
num_classes: The number of classes that this network should classify to. If
......
@@ -49,6 +49,9 @@ class EncoderScaffold(tf.keras.Model):
If `hidden_cls` is not overridden, a default transformer layer will be
instantiated.
*Note* that the network is constructed with the
[Keras Functional API](https://keras.io/guides/functional_api/).
Arguments:
pooled_output_dim: The dimension of pooled output.
pooler_layer_initializer: The initializer for the classification
......
@@ -27,6 +27,8 @@ class SpanLabeling(tf.keras.Model):
"""Span labeling network head for BERT modeling.
This network implements a simple single-span labeler based on a dense layer.
*Note* that the network is constructed with the
[Keras Functional API](https://keras.io/guides/functional_api/).
Arguments:
input_width: The innermost dimension of the input tensor to this network.
......
@@ -27,6 +27,8 @@ class TokenClassification(tf.keras.Model):
"""TokenClassification network head for BERT modeling.
This network implements a simple token classifier head based on a dense layer.
*Note* that the network is constructed with the
[Keras Functional API](https://keras.io/guides/functional_api/).
Arguments:
input_width: The innermost dimension of the input tensor to this network.
......
@@ -39,6 +39,9 @@ class TransformerEncoder(tf.keras.Model):
in "BERT: Pre-training of Deep Bidirectional Transformers for Language
Understanding".
*Note* that the network is constructed with the
[Keras Functional API](https://keras.io/guides/functional_api/).
Arguments:
vocab_size: The size of the token vocabulary.
hidden_size: The size of the transformer hidden layers.
......
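Putting the pieces together, here is a usage sketch under the assumption that the `official.nlp.modeling` package at roughly this revision exposes `networks.TransformerEncoder` and `models.BertClassifier`; argument names and defaults may differ between versions.

```python
# Assumed import paths and constructor arguments; check the package at this
# revision before relying on them.
from official.nlp.modeling import models, networks

encoder = networks.TransformerEncoder(
    vocab_size=30522,   # size of the token vocabulary
    num_layers=12,      # number of transformer layers
    hidden_size=768)    # size of the transformer hidden layers

# The encoder network is passed to a task model such as BertClassifier, which
# attaches its classification head on top.
classifier = models.BertClassifier(network=encoder, num_classes=2)
```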