Commit 651677f5 authored by Hongkun Yu, committed by A. Unique TensorFlower

Attributes -> Arguments. Be consistent with Keras style.

PiperOrigin-RevId: 298692558
parent 1ac65814
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 """Keras-based attention layer."""
+# pylint: disable=g-classes-have-attributes
 from __future__ import absolute_import
 from __future__ import division
 # from __future__ import google_type_annotations
@@ -45,7 +45,7 @@ class Attention(tf.keras.layers.Layer):
 interpolated by these probabilities, then concatenated back to a single
 tensor and returned.
-Attributes:
+Arguments:
 num_heads: Number of attention heads.
 head_size: Size of each attention head.
 dropout: Dropout probability.
@@ -186,7 +186,7 @@ class Attention(tf.keras.layers.Layer):
 class CachedAttention(Attention):
 """Attention layer with cache used for auto-regressive decoding.
-Attributes:
+Arguments:
 num_heads: Number of attention heads.
 head_size: Size of each attention head.
 **kwargs: Other keyword arguments inherit from `Attention` class.
...
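
A minimal construction sketch for the two layers touched above, assuming the documented arguments map directly onto the constructors; the import path and any remaining keyword arguments are assumptions, since the diff shows only the docstrings.

```python
from official.nlp.modeling import layers  # import path assumed

# Multi-head attention built from the three documented arguments.
attention_layer = layers.Attention(num_heads=8, head_size=64, dropout=0.1)

# The cached variant takes the same arguments and keeps a cache so that
# auto-regressive decoding does not recompute past positions.
cached_layer = layers.CachedAttention(num_heads=8, head_size=64)
```
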
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 """Keras-based einsum layer."""
+# pylint: disable=g-classes-have-attributes
 from __future__ import absolute_import
 from __future__ import division
 # from __future__ import google_type_annotations
@@ -30,7 +30,7 @@ class DenseEinsum(tf.keras.layers.Layer):
 This layer can perform einsum calculations of arbitrary dimensionality.
-Attributes:
+Arguments:
 output_shape: Positive integer or tuple, dimensionality of the output space.
 num_summed_dimensions: The number of dimensions to sum over. Standard 2D
 matmul should use 1, 3D matmul should use 2, and so forth.
...
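
The `num_summed_dimensions` argument counts how many axes the underlying einsum contracts. A rough illustration in plain `tf.einsum` (the layer's actual equation strings are not visible in this diff):

```python
import tensorflow as tf

# num_summed_dimensions = 1: an ordinary 2D matmul contracts one axis.
x2 = tf.random.normal([4, 16])
w2 = tf.random.normal([16, 32])
y2 = tf.einsum('ab,bc->ac', x2, w2)    # -> [4, 32]

# num_summed_dimensions = 2: a "3D matmul" contracts two axes at once.
x3 = tf.random.normal([4, 8, 16])
w3 = tf.random.normal([8, 16, 32])
y3 = tf.einsum('abc,bcd->ad', x3, w3)  # -> [4, 32]
```
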
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 """Keras-based softmax layer with optional masking."""
+# pylint: disable=g-classes-have-attributes
 from __future__ import absolute_import
 from __future__ import division
 # from __future__ import google_type_annotations
@@ -26,7 +26,7 @@ import tensorflow as tf
 class MaskedSoftmax(tf.keras.layers.Layer):
 """Performs a softmax with optional masking on a tensor.
-Attributes:
+Arguments:
 mask_expansion_axes: Any axes that should be padded on the mask tensor.
 """
...
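
As a generic sketch of what a masked softmax does, and why a mask sometimes needs extra padded axes (the role `mask_expansion_axes` plays here); this illustrates the convention, not necessarily this layer's exact implementation:

```python
import tensorflow as tf

scores = tf.random.normal([2, 4, 4])                            # e.g. attention scores
mask = tf.cast(tf.sequence_mask([3, 2], maxlen=4), tf.float32)  # [2, 4]

# Pad a broadcast axis onto the mask so it lines up with the scores.
mask = tf.expand_dims(mask, axis=1)                             # [2, 1, 4]

# Push masked positions toward -inf before the softmax so they get ~0 weight.
probs = tf.nn.softmax(scores + (1.0 - mask) * -1e9)
```
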
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 """Keras-based one-hot embedding layer."""
+# pylint: disable=g-classes-have-attributes
 from __future__ import absolute_import
 from __future__ import division
 # from __future__ import google_type_annotations
@@ -31,7 +31,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
 This layer uses either tf.gather or tf.one_hot to translate integer indices to
 float embeddings.
-Attributes:
+Arguments:
 vocab_size: Number of elements in the vocabulary.
 embedding_width: Output size of the embedding layer.
 initializer: The initializer to use for the embedding weights. Defaults to
...
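
The docstring names the two lookup strategies directly, so they can be compared in plain TensorFlow; the table and ids below are hypothetical:

```python
import tensorflow as tf

vocab_size, embedding_width = 100, 8
table = tf.random.normal([vocab_size, embedding_width])
ids = tf.constant([[3, 7, 42]])

# Strategy 1: gather rows from the embedding table.
by_gather = tf.gather(table, ids)                               # [1, 3, 8]

# Strategy 2: one-hot encode and contract; often faster on TPUs.
one_hot = tf.one_hot(ids, depth=vocab_size, dtype=table.dtype)  # [1, 3, 100]
by_one_hot = tf.einsum('bsv,vh->bsh', one_hot, table)           # [1, 3, 8]

# Both paths produce the same float embeddings.
tf.debugging.assert_near(by_gather, by_one_hot)
```
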
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 """Keras-based positional embedding layer."""
+# pylint: disable=g-classes-have-attributes
 from __future__ import absolute_import
 from __future__ import division
 # from __future__ import google_type_annotations
@@ -37,7 +37,7 @@ class PositionEmbedding(tf.keras.layers.Layer):
 can have a dynamic 1st dimension, while if `use_dynamic_slicing` is False the
 input size must be fixed.
-Attributes:
+Arguments:
 use_dynamic_slicing: Whether to use the dynamic slicing path.
 max_sequence_length: The maximum size of the dynamic sequence. Only
 applicable if `use_dynamic_slicing` is True.
...
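
A construction sketch under the same caveat as the earlier examples (import path and values assumed):

```python
from official.nlp.modeling import layers  # import path assumed

# With dynamic slicing enabled, the input's sequence dimension may vary
# at run time, up to the stated maximum.
position_embedding = layers.PositionEmbedding(
    use_dynamic_slicing=True, max_sequence_length=512)
```
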
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 """Keras-based transformer block layer."""
+# pylint: disable=g-classes-have-attributes
 from __future__ import absolute_import
 from __future__ import division
 # from __future__ import google_type_annotations
@@ -32,7 +32,7 @@ class Transformer(tf.keras.layers.Layer):
 This layer implements the Transformer from "Attention Is All You Need".
 (https://arxiv.org/abs/1706.03762).
-Attributes:
+Arguments:
 num_attention_heads: Number of attention heads.
 intermediate_size: Size of the intermediate layer.
 intermediate_activation: Activation for the intermediate layer.
...
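
A hypothetical construction with the three documented arguments (values and import path assumed; other constructor options are not visible in this diff):

```python
from official.nlp.modeling import layers  # import path assumed

transformer_block = layers.Transformer(
    num_attention_heads=8,
    intermediate_size=2048,
    intermediate_activation='relu')
```
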
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 """Keras-based transformer scaffold layer."""
+# pylint: disable=g-classes-have-attributes
 from __future__ import absolute_import
 from __future__ import division
 # from __future__ import google_type_annotations
@@ -35,7 +35,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
 `attention_cfg`, in which case the scaffold will instantiate the class with
 the config, or pass a class instance to `attention_cls`.
-Attributes:
+Arguments:
 num_attention_heads: Number of attention heads.
 intermediate_size: Size of the intermediate layer.
 intermediate_activation: Activation for the intermediate layer.
...
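
The class-plus-config pattern described above, sketched with assumed config keys (the diff does not show what `attention_cfg` must contain):

```python
from official.nlp.modeling import layers  # import path assumed

scaffold = layers.TransformerScaffold(
    num_attention_heads=8,
    intermediate_size=2048,
    intermediate_activation='relu',
    # Passing a class: the scaffold instantiates it with attention_cfg.
    attention_cls=layers.Attention,
    attention_cfg={'num_heads': 8, 'head_size': 64})  # keys assumed

# Alternatively, pass an already-constructed instance to attention_cls
# and omit attention_cfg.
```
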
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 """ALBERT (https://arxiv.org/abs/1810.04805) text encoder network."""
+# pylint: disable=g-classes-have-attributes
 from __future__ import absolute_import
 from __future__ import division
 # from __future__ import google_type_annotations
@@ -41,7 +41,7 @@ class AlbertTransformerEncoder(network.Network):
 The default values for this object are taken from the ALBERT-Base
 implementation described in the paper.
-Attributes:
+Arguments:
 vocab_size: The size of the token vocabulary.
 embedding_width: The width of the word embeddings. If the embedding width
 is not equal to hidden size, embedding parameters will be factorized into
...
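
When `embedding_width` differs from the hidden size, the embedding matrix factorizes into two smaller ones, which is the main parameter saving ALBERT describes. Back-of-the-envelope with ALBERT-Base-like sizes (values assumed, not from this diff):

```python
vocab_size, hidden_size, embedding_width = 30000, 768, 128

unfactorized = vocab_size * hidden_size        # 23,040,000 parameters
factorized = (vocab_size * embedding_width     # V x E
              + embedding_width * hidden_size) # + E x H
print(unfactorized, factorized)                # 23040000 3938304
```
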
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 """Trainer network for BERT-style models."""
+# pylint: disable=g-classes-have-attributes
 from __future__ import absolute_import
 from __future__ import division
 # from __future__ import google_type_annotations
@@ -36,7 +36,7 @@ class BertClassifier(tf.keras.Model):
 instantiates a classification network based on the passed `num_classes`
 argument.
-Attributes:
+Arguments:
 network: A transformer network. This network should output a sequence output
 and a classification output. Furthermore, it should expose its embedding
 table via a "get_embedding_table" method.
...
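
A wrapping sketch based on the documented contract; the encoder values and import path are assumed. `BertPretrainer` and `BertSpanLabeler` below wrap a network in the same way:

```python
from official.nlp.modeling import networks  # import path assumed

# Any transformer network that outputs a sequence output plus a
# classification output and exposes get_embedding_table() qualifies.
encoder = networks.TransformerEncoder(
    vocab_size=30522, hidden_size=768, num_layers=12)  # values assumed

classifier = networks.BertClassifier(network=encoder, num_classes=2)
```
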
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 """Trainer network for BERT-style models."""
+# pylint: disable=g-classes-have-attributes
 from __future__ import absolute_import
 from __future__ import division
 # from __future__ import google_type_annotations
@@ -37,7 +37,7 @@ class BertPretrainer(tf.keras.Model):
 instantiates the masked language model and classification networks that are
 used to create the training objectives.
-Attributes:
+Arguments:
 network: A transformer network. This network should output a sequence output
 and a classification output. Furthermore, it should expose its embedding
 table via a "get_embedding_table" method.
...
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 """Trainer network for BERT-style models."""
+# pylint: disable=g-classes-have-attributes
 from __future__ import absolute_import
 from __future__ import division
 # from __future__ import google_type_annotations
@@ -35,7 +35,7 @@ class BertSpanLabeler(tf.keras.Model):
 The BertSpanLabeler allows a user to pass in a transformer stack, and
 instantiates a span labeling network based on a single dense layer.
-Attributes:
+Arguments:
 network: A transformer network. This network should output a sequence output
 and a classification output. Furthermore, it should expose its embedding
 table via a "get_embedding_table" method.
...
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 """Classification network."""
+# pylint: disable=g-classes-have-attributes
 from __future__ import absolute_import
 from __future__ import division
 # from __future__ import google_type_annotations
@@ -31,7 +31,7 @@ class Classification(network.Network):
 This network implements a simple classifier head based on a dense layer.
-Attributes:
+Arguments:
 input_width: The innermost dimension of the input tensor to this network.
 num_classes: The number of classes that this network should classify to.
 activation: The activation, if any, for the dense layer in this network.
...
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 """Transformer-based text encoder network."""
+# pylint: disable=g-classes-have-attributes
 from __future__ import absolute_import
 from __future__ import division
 # from __future__ import google_type_annotations
@@ -46,7 +46,7 @@ class EncoderScaffold(network.Network):
 If the hidden_cls is not overridden, a default transformer layer will be
 instantiated.
-Attributes:
+Arguments:
 num_output_classes: The output size of the classification layer.
 classification_layer_initializer: The initializer for the classification
 layer.
...
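
A partial construction sketch; the scaffold almost certainly needs further configuration (embedding and hidden-layer settings) that this diff does not show, so only the two documented arguments appear:

```python
import tensorflow as tf
from official.nlp.modeling import networks  # import path assumed

scaffold = networks.EncoderScaffold(
    num_output_classes=2,
    classification_layer_initializer=tf.keras.initializers.TruncatedNormal(
        stddev=0.02))
# hidden_cls is optional: omit it to get the default transformer layer,
# or pass a custom class (with a config) or instance to swap in another block.
```
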
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 """Masked language model network."""
+# pylint: disable=g-classes-have-attributes
 from __future__ import absolute_import
 from __future__ import division
 # from __future__ import google_type_annotations
@@ -32,7 +32,7 @@ class MaskedLM(network.Network):
 This network implements a masked language model based on the provided network.
 It assumes that the network being passed has a "get_embedding_table()" method.
-Attributes:
+Arguments:
 input_width: The innermost dimension of the input tensor to this network.
 num_predictions: The number of predictions to make per sequence.
 source_network: The network with the embedding layer to use for the
...
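
A sketch of the embedding-table reuse described above (import path and encoder values assumed):

```python
from official.nlp.modeling import networks  # import path assumed

encoder = networks.TransformerEncoder(
    vocab_size=30522, hidden_size=768, num_layers=12)  # values assumed

# The head borrows encoder.get_embedding_table() to project its
# transformed hidden states back onto the vocabulary.
masked_lm = networks.MaskedLM(
    input_width=768, num_predictions=20, source_network=encoder)
```
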
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 """Span labeling network."""
+# pylint: disable=g-classes-have-attributes
 from __future__ import absolute_import
 from __future__ import division
 # from __future__ import google_type_annotations
@@ -31,7 +31,7 @@ class SpanLabeling(network.Network):
 This network implements a simple single-span labeler based on a dense layer.
-Attributes:
+Arguments:
 input_width: The innermost dimension of the input tensor to this network.
 activation: The activation, if any, for the dense layer in this network.
 initializer: The initializer for the dense layer in this network. Defaults to
...
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 """Transformer-based text encoder network."""
+# pylint: disable=g-classes-have-attributes
 from __future__ import absolute_import
 from __future__ import division
 # from __future__ import google_type_annotations
@@ -40,7 +40,7 @@ class TransformerEncoder(network.Network):
 in "BERT: Pre-training of Deep Bidirectional Transformers for Language
 Understanding".
-Attributes:
+Arguments:
 vocab_size: The size of the token vocabulary.
 hidden_size: The size of the transformer hidden layers.
 num_layers: The number of transformer layers.
...