Commit 7e47cd7b authored by Hongkun Yu, committed by A. Unique TensorFlower

Adds clear documentation: Functional/Subclass API used for each network/model.

PiperOrigin-RevId: 321591514
parent 982f457a
@@ -37,6 +37,9 @@ class BertClassifier(tf.keras.Model):
 instantiates a classification network based on the passed `num_classes`
 argument. If `num_classes` is set to 1, a regression network is instantiated.
+*Note* that the model is constructed by
+[Keras Functional API](https://keras.io/guides/functional_api/).
 Arguments:
 network: A transformer network. This network should output a sequence output
 and a classification output. Furthermore, it should expose its embedding
...
@@ -41,6 +41,9 @@ class BertPretrainer(tf.keras.Model):
 instantiates the masked language model and classification networks that are
 used to create the training objectives.
+*Note* that the model is constructed by
+[Keras Functional API](https://keras.io/guides/functional_api/).
 Arguments:
 network: A transformer network. This network should output a sequence output
 and a classification output.
...
@@ -32,9 +32,12 @@ class BertSpanLabeler(tf.keras.Model):
 encoder as described in "BERT: Pre-training of Deep Bidirectional Transformers
 for Language Understanding" (https://arxiv.org/abs/1810.04805).
-The BertSpanLabeler allows a user to pass in a transformer stack, and
+The BertSpanLabeler allows a user to pass in a transformer encoder, and
 instantiates a span labeling network based on a single dense layer.
+*Note* that the model is constructed by
+[Keras Functional API](https://keras.io/guides/functional_api/).
 Arguments:
 network: A transformer network. This network should output a sequence output
 and a classification output. Furthermore, it should expose its embedding
...
@@ -36,6 +36,9 @@ class BertTokenClassifier(tf.keras.Model):
 instantiates a token classification network based on the passed `num_classes`
 argument.
+*Note* that the model is constructed by
+[Keras Functional API](https://keras.io/guides/functional_api/).
 Arguments:
 network: A transformer network. This network should output a sequence output
 and a classification output. Furthermore, it should expose its embedding
...
@@ -39,6 +39,9 @@ class ElectraPretrainer(tf.keras.Model):
 model (at generator side) and classification networks (at discriminator side)
 that are used to create the training objectives.
+*Note* that the model is constructed by Keras Subclass API, where layers are
+defined inside __init__ and call() implements the computation.
 Arguments:
 generator_network: A transformer network for generator, this network should
 output a sequence output and an optional classification output.
...
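The note added to `ElectraPretrainer` describes the Keras Subclass API pattern. As a rough illustration of that pattern only, the sketch below defines layers inside `__init__` and puts the computation in `call()`. The class name `ToyPretrainer`, its arguments, and the layer sizes are hypothetical and are not taken from this commit.

```python
import tensorflow as tf


class ToyPretrainer(tf.keras.Model):
    """Hypothetical model (not from this commit) showing the Subclass API."""

    def __init__(self, hidden_size: int, num_classes: int, **kwargs):
        super().__init__(**kwargs)
        # Layers are created once, inside __init__.
        self.hidden = tf.keras.layers.Dense(hidden_size, activation="relu")
        self.classifier = tf.keras.layers.Dense(num_classes)

    def call(self, inputs):
        # call() implements the forward computation using those layers.
        return self.classifier(self.hidden(inputs))


model = ToyPretrainer(hidden_size=16, num_classes=2)
outputs = model(tf.zeros([4, 8]))  # weights are built on the first call
```

With this style the layer graph is not known until the model is called on data, which is the practical difference from the Functional API models elsewhere in this diff.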
@@ -40,6 +40,8 @@ class AlbertTransformerEncoder(tf.keras.Model):
 The default values for this object are taken from the ALBERT-Base
 implementation described in the paper.
+*Note* that the network is constructed by Keras Functional API.
 Arguments:
 vocab_size: The size of the token vocabulary.
 embedding_width: The width of the word embeddings. If the embedding width is
...
@@ -29,6 +29,9 @@ class Classification(tf.keras.Model):
 This network implements a simple classifier head based on a dense layer. If
 num_classes is one, it can be considered as a regression problem.
+*Note* that the network is constructed by
+[Keras Functional API](https://keras.io/guides/functional_api/).
 Arguments:
 input_width: The innermost dimension of the input tensor to this network.
 num_classes: The number of classes that this network should classify to. If
...
@@ -49,6 +49,9 @@ class EncoderScaffold(tf.keras.Model):
 If the hidden_cls is not overridden, a default transformer layer will be
 instantiated.
+*Note* that the network is constructed by
+[Keras Functional API](https://keras.io/guides/functional_api/).
 Arguments:
 pooled_output_dim: The dimension of pooled output.
 pooler_layer_initializer: The initializer for the classification
...
@@ -27,6 +27,8 @@ class SpanLabeling(tf.keras.Model):
 """Span labeling network head for BERT modeling.
 This network implements a simple single-span labeler based on a dense layer.
+*Note* that the network is constructed by
+[Keras Functional API](https://keras.io/guides/functional_api/).
 Arguments:
 input_width: The innermost dimension of the input tensor to this network.
...
@@ -27,6 +27,8 @@ class TokenClassification(tf.keras.Model):
 """TokenClassification network head for BERT modeling.
 This network implements a simple token classifier head based on a dense layer.
+*Note* that the network is constructed by
+[Keras Functional API](https://keras.io/guides/functional_api/).
 Arguments:
 input_width: The innermost dimension of the input tensor to this network.
...
@@ -39,6 +39,9 @@ class TransformerEncoder(tf.keras.Model):
 in "BERT: Pre-training of Deep Bidirectional Transformers for Language
 Understanding".
+*Note* that the network is constructed by
+[Keras Functional API](https://keras.io/guides/functional_api/).
 Arguments:
 vocab_size: The size of the token vocabulary.
 hidden_size: The size of the transformer hidden layers.
...
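Most of the notes in this commit point to the Keras Functional API. For readers unfamiliar with it, here is a minimal sketch of a dense classifier head wired up that way. The helper `build_classifier_head` and the model name are hypothetical; the `input_width` and `num_classes` parameters borrow their names from the `Classification` docstring, but this is not that class's implementation.

```python
import tensorflow as tf


def build_classifier_head(input_width: int, num_classes: int) -> tf.keras.Model:
    """Hypothetical helper (not from this commit) using the Functional API."""
    # A symbolic input tensor; layers are connected at construction time.
    inputs = tf.keras.Input(shape=(input_width,), name="cls_output")
    logits = tf.keras.layers.Dense(num_classes, name="logits")(inputs)
    # The model is fully described by its input and output tensors.
    return tf.keras.Model(inputs=inputs, outputs=logits, name="toy_classification")


head = build_classifier_head(input_width=8, num_classes=3)
logits = head(tf.zeros([2, 8]))  # batch of 2 examples, 3 logits each
```

Because the graph of layers is fixed when the model is built, a Functional model can be inspected (for example with `head.summary()`) before it ever sees data.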