Commit d48574cb authored by Chen Chen, committed by A. Unique TensorFlower

Use backticks to denote code spans in nlp modeling docstrings

PiperOrigin-RevId: 362610475
parent 454f8be7
@@ -25,10 +25,10 @@ def _large_compatible_negative(tensor_type):
 in this module (-1e9) cannot be represented using `tf.float16`.
 Args:
-  tensor_type: a dtype to determine the type.
+  tensor_type: A dtype to determine the type.
 Returns:
-  a large negative number.
+  A large negative number.
 """
 if tensor_type == tf.float16:
   return tf.float16.min
...
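For context, a minimal sketch of how a helper like this is typically used: the large negative value is added to masked-out attention logits so they vanish after the softmax, and `tf.float16.min` is substituted because `-1e9` overflows `float16`. The usage below is illustrative, not code from this repository.

```python
import tensorflow as tf

def large_compatible_negative(tensor_type):
  # Mirrors the helper documented above: -1e9 overflows float16, so fall
  # back to the smallest representable float16 value instead.
  if tensor_type == tf.float16:
    return tf.float16.min
  return -1e9

# Illustrative use: push masked-out logits toward -inf before a softmax.
logits = tf.random.normal([2, 4], dtype=tf.float16)
mask = tf.constant([[1, 1, 0, 0], [1, 0, 0, 0]], dtype=tf.float16)
masked_logits = logits + (1.0 - mask) * large_compatible_negative(logits.dtype)
probs = tf.nn.softmax(masked_logits, axis=-1)  # masked positions get ~0 probability
```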
@@ -44,7 +44,7 @@ def _get_norm_layer(normalization_type='no_norm', name=None):
 Args:
   normalization_type: String. The type of normalization_type, only
-    'no_norm' and 'layer_norm' are supported.
+    `no_norm` and `layer_norm` are supported.
   name: Name for the norm layer.
 Returns:
@@ -89,7 +89,7 @@ class MobileBertEmbedding(tf.keras.layers.Layer):
   output_embed_size: Embedding size for the final embedding output.
   max_sequence_length: Maximum length of input sequence.
   normalization_type: String. The type of normalization_type, only
-    'no_norm' and 'layer_norm' are supported.
+    `no_norm` and `layer_norm` are supported.
   initializer: The initializer to use for the embedding weights and
     linear projection weights.
   dropout_rate: Dropout rate.
@@ -208,10 +208,10 @@ class MobileBertTransformer(tf.keras.layers.Layer):
   key_query_shared_bottleneck: Whether to share linear transformation for
     keys and queries.
   num_feedforward_networks: Number of stacked feed-forward networks.
-  normalization_type: The type of normalization_type, only 'no_norm' and
-    'layer_norm' are supported. 'no_norm' represents the element-wise
+  normalization_type: The type of normalization_type, only `no_norm` and
+    `layer_norm` are supported. `no_norm` represents the element-wise
     linear transformation for the student model, as suggested by the
-    original MobileBERT paper. 'layer_norm' is used for the teacher model.
+    original MobileBERT paper. `layer_norm` is used for the teacher model.
   initializer: The initializer to use for the embedding weights and
     linear projection weights.
   **kwargs: keyword arguments.
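The `no_norm` option described above replaces layer normalization with a per-feature scale and shift. A hedged sketch of what such a layer and a selector like `_get_norm_layer` could look like; names and details below are illustrative, not the repository's actual implementation.

```python
import tensorflow as tf

class NoNorm(tf.keras.layers.Layer):
  """Element-wise linear transformation used in place of layer norm (sketch)."""

  def build(self, input_shape):
    dim = input_shape[-1]
    self.beta = self.add_weight('beta', shape=[dim], initializer='zeros')
    self.gamma = self.add_weight('gamma', shape=[dim], initializer='ones')

  def call(self, x):
    # No mean/variance statistics are computed, just a per-feature affine map.
    return x * self.gamma + self.beta

def get_norm_layer(normalization_type='no_norm', name=None):
  if normalization_type == 'no_norm':
    return NoNorm(name=name)
  if normalization_type == 'layer_norm':
    return tf.keras.layers.LayerNormalization(name=name, axis=-1, epsilon=1e-12)
  raise ValueError(f'Unsupported normalization type: {normalization_type}')
```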
@@ -346,14 +346,16 @@ class MobileBertTransformer(tf.keras.layers.Layer):
 """Implements the forward pass.
 Args:
-  input_tensor: Float tensor of shape [batch_size, seq_length, hidden_size].
-  attention_mask: (optional) int32 tensor of shape [batch_size, seq_length,
-    seq_length], with 1 for positions that can be attended to and 0 in
-    positions that should not be.
+  input_tensor: Float tensor of shape
+    `(batch_size, seq_length, hidden_size)`.
+  attention_mask: (optional) int32 tensor of shape
+    `(batch_size, seq_length, seq_length)`, with 1 for positions that can
+    be attended to and 0 in positions that should not be.
   return_attention_scores: If return attention score.
 Returns:
-  layer_output: Float tensor of shape [batch_size, seq_length, hidden_size].
+  layer_output: Float tensor of shape
+    `(batch_size, seq_length, hidden_size)`.
   attention_scores (Optional): Only when return_attention_scores is True.
 Raises:
@@ -450,8 +452,8 @@ class MobileBertMaskedLM(tf.keras.layers.Layer):
   activation: The activation, if any, for the dense layer.
   initializer: The initializer for the dense layer. Defaults to a Glorot
     uniform initializer.
-  output: The output style for this layer. Can be either 'logits' or
-    'predictions'.
+  output: The output style for this layer. Can be either `logits` or
+    `predictions`.
   **kwargs: keyword arguments.
 """
 super(MobileBertMaskedLM, self).__init__(**kwargs)
@@ -527,16 +529,16 @@ class MobileBertMaskedLM(tf.keras.layers.Layer):
 Args:
   sequence_tensor: Sequence output of `BertModel` layer of shape
-    (`batch_size`, `seq_length`, num_hidden) where num_hidden is number of
+    `(batch_size, seq_length, num_hidden)` where `num_hidden` is number of
     hidden units of `BertModel` layer.
   positions: Positions ids of tokens in sequence to mask for pretraining
-    of with dimension (batch_size, num_predictions) where
+    of with dimension `(batch_size, num_predictions)` where
     `num_predictions` is maximum number of tokens to mask out and predict
     per each sequence.
 Returns:
-  Masked out sequence tensor of shape (batch_size * num_predictions,
-  num_hidden).
+  Masked out sequence tensor of shape
+  `(batch_size * num_predictions, num_hidden)`.
 """
 sequence_shape = tf.shape(sequence_tensor)
 batch_size, seq_length = sequence_shape[0], sequence_shape[1]
...
@@ -26,8 +26,8 @@ class VotingAttention(tf.keras.layers.Layer):
 """Voting Attention layer.
 Args:
-  num_heads: the number of attention heads.
-  head_size: per-head hidden size.
+  num_heads: The number of attention heads.
+  head_size: Per-head hidden size.
   kernel_initializer: Initializer for dense layer kernels.
   bias_initializer: Initializer for dense layer biases.
   kernel_regularizer: Regularizer for dense layer kernels.
@@ -115,7 +115,7 @@ class MultiChannelAttention(tf.keras.layers.MultiHeadAttention):
     context tensors according to the distribution among channels.
   key: Optional key `Tensor` of shape `[B, A, S, dim]`. If not given, will use
     `value` for both `key` and `value`, which is the most common case.
-  attention_mask: a boolean mask of shape `[B, T, S]`, that prevents attention
+  attention_mask: A boolean mask of shape `[B, T, S]`, that prevents attention
     to certain positions.
 """
...
@@ -77,7 +77,7 @@ class RelativePositionEmbedding(tf.keras.layers.Layer):
     dimension of `inputs`.
 Returns:
-  A tensor in shape of [length, hidden_size].
+  A tensor in shape of `(length, hidden_size)`.
 """
 if inputs is None and length is None:
   raise ValueError("If inputs is None, `length` must be set in "
@@ -114,7 +114,7 @@ def _relative_position_bucket(relative_position,
 the distance in tokens from the attending position to the attended-to
 position.
-If bidirectional=False, then positive relative positions are invalid.
+If `bidirectional=False`, then positive relative positions are invalid.
 We use smaller buckets for small absolute relative_position and larger
 buckets for larger absolute relative_positions.
@@ -127,13 +127,13 @@ def _relative_position_bucket(relative_position,
 than the model has been trained on.
 Args:
-  relative_position: an int32 Tensor
-  bidirectional: a boolean - whether the attention is bidirectional
-  num_buckets: an integer
-  max_distance: an integer
+  relative_position: An int32 Tensor
+  bidirectional: A boolean - whether the attention is bidirectional
+  num_buckets: An integer
+  max_distance: An integer
 Returns:
-  a Tensor with the same shape as relative_position, containing int32
+  A Tensor with the same shape as relative_position, containing int32
   values in the range [0, num_buckets)
 """
 ret = 0
...
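The bucketing described above follows the T5-style scheme: exact buckets for small distances and logarithmically growing buckets up to `max_distance`. A simplified NumPy sketch of that idea (illustrative only, not the exact TF implementation in this file):

```python
import numpy as np

def relative_position_bucket(relative_position, bidirectional=True,
                             num_buckets=32, max_distance=128):
  """Simplified NumPy sketch of the log-bucketing idea described above."""
  rp = np.asarray(relative_position)
  ret = np.zeros_like(rp)
  n = -rp
  if bidirectional:
    # Half the buckets are reserved for positive relative positions.
    num_buckets //= 2
    ret += (n < 0).astype(np.int32) * num_buckets
    n = np.abs(n)
  else:
    n = np.maximum(n, 0)
  max_exact = num_buckets // 2
  is_small = n < max_exact
  # Larger distances map logarithmically into the remaining buckets.
  val_if_large = max_exact + (
      np.log(n.clip(min=1) / max_exact) / np.log(max_distance / max_exact)
      * (num_buckets - max_exact)
  ).astype(np.int32)
  val_if_large = np.minimum(val_if_large, num_buckets - 1)
  return ret + np.where(is_small, n, val_if_large)

print(relative_position_bucket(np.arange(-8, 9)))  # small distances get their own bucket
```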
@@ -103,10 +103,10 @@ class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention):
   segment_attention_bias: Optional trainable bias parameter added to the
     query head when calculating the segment-based attention score used in
     XLNet of shape `[num_heads, dim]`.
-  state: Optional `Tensor` of shape [B, M, E] where M is the length of the
+  state: Optional `Tensor` of shape `[B, M, E]` where M is the length of the
     state or memory.
     If passed, this is also attended over as in Transformer XL.
-  attention_mask: a boolean mask of shape `[B, T, S]` that prevents attention
+  attention_mask: A boolean mask of shape `[B, T, S]` that prevents attention
     to certain positions.
 """
...
@@ -21,15 +21,15 @@ from official.nlp.keras_nlp import layers
 @tf.keras.utils.register_keras_serializable(package='Text')
 class SelfAttentionMask(layers.SelfAttentionMask):
-"""Create 3D attention mask from a 2D tensor mask.
+"""Creates 3D attention mask from a 2D tensor mask.
 **Warning: Please use the `keras_nlp.layers.SelfAttentionMask`.**
 inputs[0]: from_tensor: 2D or 3D Tensor of shape
-  [batch_size, from_seq_length, ...].
-inputs[1]: to_mask: int32 Tensor of shape [batch_size, to_seq_length].
+  `(batch_size, from_seq_length, ...)`.
+inputs[1]: to_mask: int32 Tensor of shape `(batch_size, to_seq_length)`.
 Returns:
-  float Tensor of shape [batch_size, from_seq_length, to_seq_length].
+  Float Tensor of shape `(batch_size, from_seq_length, to_seq_length)`.
 """
 def call(self, inputs):
...
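The mask construction this layer documents can be reproduced with plain broadcasting; a minimal sketch matching the shapes in the docstring (illustrative only, not the keras_nlp implementation):

```python
import tensorflow as tf

def self_attention_mask(from_tensor, to_mask):
  # from_tensor: (batch_size, from_seq_length, ...); to_mask: (batch_size, to_seq_length).
  from_seq_length = tf.shape(from_tensor)[1]
  to_mask = tf.cast(to_mask[:, tf.newaxis, :], from_tensor.dtype)   # (B, 1, to_seq_length)
  ones = tf.ones([tf.shape(to_mask)[0], from_seq_length, 1], dtype=from_tensor.dtype)
  return ones * to_mask                                             # (B, from_seq_length, to_seq_length)

embeddings = tf.zeros([2, 5, 8])  # (batch_size, from_seq_length, width)
pad_mask = tf.constant([[1, 1, 1, 0, 0],
                        [1, 1, 0, 0, 0]], dtype=tf.int32)
mask_3d = self_attention_mask(embeddings, pad_mask)  # float tensor of shape (2, 5, 5)
```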
@@ -63,7 +63,7 @@ class TalkingHeadsAttention(tf.keras.layers.MultiHeadAttention):
     that will be applied on attention scores before and after softmax.
 Args:
-  qkv_rank: the rank of query, key, value tensors after projection.
+  qkv_rank: The rank of query, key, value tensors after projection.
 """
 super(TalkingHeadsAttention, self)._build_attention(qkv_rank)
...
@@ -100,10 +100,10 @@ class BertTokenizer(tf.keras.layers.Layer):
   tokenize_with_offsets: If true, calls
     `text.BertTokenizer.tokenize_with_offsets()` instead of plain
     `text.BertTokenizer.tokenize()` and outputs a triple of
-    (tokens, start_offsets, limit_offsets).
-  raw_table_access: An object with methods .lookup(keys) and .size()
+    `(tokens, start_offsets, limit_offsets)`.
+  raw_table_access: An object with methods `.lookup(keys)` and `.size()`
     that operate on the raw lookup table of tokens. It can be used to
-    look up special token synbols like [MASK].
+    look up special token symbols like `[MASK]`.
 """
 def __init__(self, *,
@@ -121,16 +121,16 @@ class BertTokenizer(tf.keras.layers.Layer):
   lower_case: A Python boolean forwarded to `text.BertTokenizer`.
     If true, input text is converted to lower case (where applicable)
     before tokenization. This must be set to match the way in which
-    the vocab_file was created.
+    the `vocab_file` was created.
   tokenize_with_offsets: A Python boolean. If true, this layer calls
     `text.BertTokenizer.tokenize_with_offsets()` instead of plain
     `text.BertTokenizer.tokenize()` and outputs a triple of
-    (tokens, start_offsets, limit_offsets)
+    `(tokens, start_offsets, limit_offsets)`
     instead of just tokens.
-  **kwargs: standard arguments to Layer().
+  **kwargs: Standard arguments to `Layer()`.
 Raises:
-  ImportError: if importing `tensorflow_text` failed.
+  ImportError: If importing `tensorflow_text` failed.
 """
 _check_if_tf_text_installed()
@@ -167,17 +167,18 @@ class BertTokenizer(tf.keras.layers.Layer):
 """Calls `text.BertTokenizer` on inputs.
 Args:
-  inputs: A string Tensor of shape [batch_size].
+  inputs: A string Tensor of shape `(batch_size,)`.
 Returns:
   One or three of `RaggedTensors` if `tokenize_with_offsets` is False or
   True, respectively. These are
-  tokens: A `RaggedTensor` of shape [batch_size, (words), (pieces_per_word)]
-    and type int32. tokens[i,j,k] contains the k-th wordpiece of the
+  tokens: A `RaggedTensor` of shape
+    `[batch_size, (words), (pieces_per_word)]`
+    and type int32. `tokens[i,j,k]` contains the k-th wordpiece of the
     j-th word in the i-th input.
   start_offsets, limit_offsets: If `tokenize_with_offsets` is True,
     RaggedTensors of type int64 with the same indices as tokens.
-    Element [i,j,k] contains the byte offset at the start, or past the
+    Element `[i,j,k]` contains the byte offset at the start, or past the
     end, resp., for the k-th wordpiece of the j-th word in the i-th input.
 """
 # Prepare to reshape the result to work around broken shape inference.
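For readers unfamiliar with the wrapped `tensorflow_text` API, a hedged usage sketch of the underlying tokenizer; it assumes `tensorflow_text` is installed and that `/tmp/vocab.txt` is a hypothetical WordPiece vocab file.

```python
import tensorflow as tf
import tensorflow_text as text

# Hypothetical vocab file path used purely for illustration.
tokenizer = text.BertTokenizer('/tmp/vocab.txt', lower_case=True)
sentences = tf.constant(['hello world', 'the quick brown fox'])

tokens = tokenizer.tokenize(sentences)
# tokens is a RaggedTensor of shape [batch_size, (words), (pieces_per_word)].

tokens, starts, limits = tokenizer.tokenize_with_offsets(sentences)
# starts/limits hold byte offsets with the same ragged structure as tokens.
```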
@@ -268,13 +269,13 @@ class BertTokenizer(tf.keras.layers.Layer):
 class SentencepieceTokenizer(tf.keras.layers.Layer):
-"""Wraps tf_text.SentencepieceTokenizer as a Keras Layer.
+"""Wraps `tf_text.SentencepieceTokenizer` as a Keras Layer.
 Attributes:
   tokenize_with_offsets: If true, calls
-    SentencepieceTokenizer.tokenize_with_offsets()
-    instead of plain .tokenize() and outputs a triple of
-    (tokens, start_offsets, limit_offsets).
+    `SentencepieceTokenizer.tokenize_with_offsets()`
+    instead of plain `.tokenize()` and outputs a triple of
+    `(tokens, start_offsets, limit_offsets)`.
 """
 def __init__(self,
...
@@ -300,9 +301,9 @@ class SentencepieceTokenizer(tf.keras.layers.Layer):
     store the actual proto (not a filename passed here).
   model_serialized_proto: The sentencepiece model serialized proto string.
   tokenize_with_offsets: A Python boolean. If true, this layer calls
-    SentencepieceTokenizer.tokenize_with_offsets() instead of
-    plain .tokenize() and outputs a triple of
-    (tokens, start_offsets, limit_offsets) insead of just tokens.
+    `SentencepieceTokenizer.tokenize_with_offsets()` instead of
+    plain `.tokenize()` and outputs a triple of
+    `(tokens, start_offsets, limit_offsets)` instead of just tokens.
     Note that when following `strip_diacritics` is set to True, returning
     offsets is not supported now.
   nbest_size: A scalar for sampling:
@@ -320,7 +321,7 @@ class SentencepieceTokenizer(tf.keras.layers.Layer):
     `tokenize_with_offsets`. NOTE: New models are encouraged to put this
     into custom normalization rules for the Sentencepiece model itself to
     avoid this extra step and the limitation regarding offsets.
-  **kwargs: standard arguments to Layer().
+  **kwargs: standard arguments to `Layer()`.
 Raises:
   ImportError: if importing tensorflow_text failed.
@@ -360,19 +361,19 @@ class SentencepieceTokenizer(tf.keras.layers.Layer):
 return self._tokenizer.vocab_size()
 def call(self, inputs: tf.Tensor):
-"""Calls text.SentencepieceTokenizer on inputs.
+"""Calls `text.SentencepieceTokenizer` on inputs.
 Args:
-  inputs: A string Tensor of shape [batch_size].
+  inputs: A string Tensor of shape `(batch_size,)`.
 Returns:
   One or three of RaggedTensors if tokenize_with_offsets is False or True,
   respectively. These are
-  tokens: A RaggedTensor of shape [batch_size, (pieces)] and type int32.
-    tokens[i,j] contains the j-th piece in the i-th input.
-  start_offsets, limit_offsets: If tokenize_with_offsets is True,
-    RaggedTensors of type int64 with the same indices as tokens.
-    Element [i,j] contains the byte offset at the start, or past the
+  tokens: A RaggedTensor of shape `[batch_size, (pieces)]` and type `int32`.
+    `tokens[i,j]` contains the j-th piece in the i-th input.
+  start_offsets, limit_offsets: If `tokenize_with_offsets` is True,
+    RaggedTensors of type `int64` with the same indices as tokens.
+    Element `[i,j]` contains the byte offset at the start, or past the
     end, resp., for the j-th piece in the i-th input.
 """
 if self._strip_diacritics:
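Similarly, a hedged sketch of driving the wrapped `tf_text.SentencepieceTokenizer` directly; it assumes `tensorflow_text` is installed and that `/tmp/sp.model` is a hypothetical serialized SentencePiece model file.

```python
import tensorflow as tf
import tensorflow_text as tf_text

# Hypothetical model path; the tokenizer takes the serialized proto bytes.
model_serialized_proto = tf.io.gfile.GFile('/tmp/sp.model', 'rb').read()
tokenizer = tf_text.SentencepieceTokenizer(model=model_serialized_proto)

sentences = tf.constant(['hello world', 'the quick brown fox'])
tokens = tokenizer.tokenize(sentences)  # RaggedTensor of shape [batch_size, (pieces)]
tokens, starts, limits = tokenizer.tokenize_with_offsets(sentences)
```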
@@ -492,7 +493,7 @@ class BertPackInputs(tf.keras.layers.Layer):
 special_tokens_dict=None,
 truncator="round_robin",
 **kwargs):
-"""Initializes with a target seq_length, relevant token ids and truncator.
+"""Initializes with a target `seq_length`, relevant token ids and truncator.
 Args:
   seq_length: The desired output length. Must not exceed the max_seq_length
@@ -505,13 +506,13 @@ class BertPackInputs(tf.keras.layers.Layer):
     unused positions after the last segment in the sequence
     (called "[PAD]" for BERT).
   special_tokens_dict: Optionally, a dict from Python strings to Python
-    integers that contains values for start_of_sequence_id,
-    end_of_segment_id and padding_id. (Further values in the dict are
+    integers that contains values for `start_of_sequence_id`,
+    `end_of_segment_id` and `padding_id`. (Further values in the dict are
     silently ignored.) If this is passed, separate *_id arguments must be
     omitted.
   truncator: The algorithm to truncate a list of batched segments to fit a
-    per-example length limit. The value can be either "round_robin" or
-    "waterfall":
+    per-example length limit. The value can be either `round_robin` or
+    `waterfall`:
     (1) For "round_robin" algorithm, available space is assigned
     one token at a time in a round-robin fashion to the inputs that still
     need some, until the limit is reached. It currently only supports
@@ -521,10 +522,10 @@ class BertPackInputs(tf.keras.layers.Layer):
     left-to-right manner and fills up the buckets until we run out of
     budget. It supports an arbitrary number of segments.
-  **kwargs: standard arguments to Layer().
+  **kwargs: standard arguments to `Layer()`.
 Raises:
-  ImportError: if importing tensorflow_text failed.
+  ImportError: if importing `tensorflow_text` failed.
 """
 _check_if_tf_text_installed()
 super().__init__(**kwargs)
...
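A plain-Python sketch of the `round_robin` budgeting documented in the two hunks above, i.e. handing out one position at a time to each segment that still has tokens left until the per-example limit is reached. This is illustrative only, not the layer's actual truncation code.

```python
def round_robin_truncate(segment_lengths, seq_length_limit):
  """Distribute the length budget one token at a time, round-robin (sketch)."""
  kept = [0] * len(segment_lengths)
  budget = seq_length_limit
  while budget > 0:
    progressed = False
    for i, full_length in enumerate(segment_lengths):
      if budget == 0:
        break
      if kept[i] < full_length:
        kept[i] += 1
        budget -= 1
        progressed = True
    if not progressed:  # every segment already fits within the limit
      break
  return kept

# Two segments of 8 and 3 tokens packed into 9 positions -> keep [6, 3].
print(round_robin_truncate([8, 3], 9))
```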
@@ -37,8 +37,8 @@ class TNExpandCondense(Layer):
 Note the input shape and output shape will be identical.
 Args:
-  proj_multiplier: Positive integer, multiple of input_shape[-1] to project
-    up to. Must be one of [2, 4, 6, 8].
+  proj_multiplier: Positive integer, multiple of `input_shape[-1]` to project
+    up to. Must be one of `[2, 4, 6, 8]`.
   use_bias: Boolean, whether the layer uses a bias vector.
   activation: Activation function to use between Expand and Condense. If you
     don't specify anything, no activation is applied
...
@@ -50,8 +50,8 @@ class BertPretrainer(tf.keras.Model):
     None, no activation will be used.
   initializer: The initializer (if any) to use in the masked LM and
     classification networks. Defaults to a Glorot uniform initializer.
-  output: The output style for this network. Can be either 'logits' or
-    'predictions'.
+  output: The output style for this network. Can be either `logits` or
+    `predictions`.
 """
 def __init__(self,
...
@@ -37,11 +37,11 @@ class BertSpanLabeler(tf.keras.Model):
 Args:
   network: A transformer network. This network should output a sequence output
     and a classification output. Furthermore, it should expose its embedding
-    table via a "get_embedding_table" method.
+    table via a `get_embedding_table` method.
   initializer: The initializer (if any) to use in the span labeling network.
     Defaults to a Glorot uniform initializer.
-  output: The output style for this network. Can be either 'logits' or
-    'predictions'.
+  output: The output style for this network. Can be either `logits` or
+    `predictions`.
 """
 def __init__(self,
...
@@ -36,12 +36,12 @@ class BertTokenClassifier(tf.keras.Model):
 Args:
   network: A transformer network. This network should output a sequence output
     and a classification output. Furthermore, it should expose its embedding
-    table via a "get_embedding_table" method.
+    table via a `get_embedding_table` method.
   num_classes: Number of classes to predict from the classification network.
   initializer: The initializer (if any) to use in the classification networks.
     Defaults to a Glorot uniform initializer.
-  output: The output style for this network. Can be either 'logits' or
-    'predictions'.
+  output: The output style for this network. Can be either `logits` or
+    `predictions`.
 """
 def __init__(self,
...
@@ -37,8 +37,8 @@ class DualEncoder(tf.keras.Model):
   normalize: If set to True, normalize the encoding produced by transformer.
   logit_scale: The scaling factor of dot products when doing training.
   logit_margin: The margin between positive and negative when doing training.
-  output: The output style for this network. Can be either 'logits' or
-    'predictions'. If set to 'predictions', it will output the embedding
+  output: The output style for this network. Can be either `logits` or
+    `predictions`. If set to `predictions`, it will output the embedding
     produced by transformer network.
 """
...
@@ -52,8 +52,8 @@ class ElectraPretrainer(tf.keras.Model):
     classification networks. If None, no activation will be used.
   mlm_initializer: The initializer (if any) to use in the masked LM and
     classification networks. Defaults to a Glorot uniform initializer.
-  output_type: The output style for this network. Can be either 'logits' or
-    'predictions'.
+  output_type: The output style for this network. Can be either `logits` or
+    `predictions`.
   disallow_correct: Whether to disallow the generator to generate the exact
     same token in the original sentence
 """
@@ -120,13 +120,13 @@ class ElectraPretrainer(tf.keras.Model):
 Returns:
   outputs: A dict of pretrainer model outputs, including
-    (1) lm_outputs: a [batch_size, num_token_predictions, vocab_size] tensor
-    indicating logits on masked positions.
-    (2) sentence_outputs: a [batch_size, num_classes] tensor indicating
+    (1) lm_outputs: A `[batch_size, num_token_predictions, vocab_size]`
+    tensor indicating logits on masked positions.
+    (2) sentence_outputs: A `[batch_size, num_classes]` tensor indicating
     logits for nsp task.
-    (3) disc_logits: a [batch_size, sequence_length] tensor indicating
+    (3) disc_logits: A `[batch_size, sequence_length]` tensor indicating
     logits for discriminator replaced token detection task.
-    (4) disc_label: a [batch_size, sequence_length] tensor indicating
+    (4) disc_label: A `[batch_size, sequence_length]` tensor indicating
     target labels for discriminator replaced token detection task.
 """
 input_word_ids = inputs['input_word_ids']
@@ -176,7 +176,7 @@ class ElectraPretrainer(tf.keras.Model):
 """Generate corrupted data for discriminator.
 Args:
-  inputs: A dict of all inputs, same as the input of call() function
+  inputs: A dict of all inputs, same as the input of `call()` function
   mlm_logits: The generator's output logits
   duplicate: Whether to copy the original inputs dict during modifications
@@ -227,16 +227,18 @@ def scatter_update(sequence, updates, positions):
 """Scatter-update a sequence.
 Args:
-  sequence: A [batch_size, seq_len] or [batch_size, seq_len, depth] tensor
-  updates: A tensor of size batch_size*seq_len(*depth)
-  positions: A [batch_size, n_positions] tensor
+  sequence: A `[batch_size, seq_len]` or `[batch_size, seq_len, depth]`
+    tensor.
+  updates: A tensor of size `batch_size*seq_len(*depth)`.
+  positions: A `[batch_size, n_positions]` tensor.
 Returns:
-  updated_sequence: A [batch_size, seq_len] or [batch_size, seq_len, depth]
-    tensor of "sequence" with elements at "positions" replaced by the values
-    at "updates". Updates to index 0 are ignored. If there are duplicated
-    positions the update is only applied once.
-  updates_mask: A [batch_size, seq_len] mask tensor of which inputs were
+  updated_sequence: A `[batch_size, seq_len]` or
+    `[batch_size, seq_len, depth]` tensor of "sequence" with elements at
+    "positions" replaced by the values at "updates". Updates to index 0 are
+    ignored. If there are duplicated positions the update is only
+    applied once.
+  updates_mask: A `[batch_size, seq_len]` mask tensor of which inputs were
     updated.
 """
 shape = tf_utils.get_shape_list(sequence, expected_rank=[2, 3])
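The semantics described above can be emulated with `tf.tensor_scatter_nd_update` once per-example positions are paired with their batch indices. A small illustrative sketch; it omits the helper's special handling of position 0 and the returned update mask.

```python
import tensorflow as tf

sequence = tf.constant([[10, 11, 12, 13],
                        [20, 21, 22, 23]])
updates = tf.constant([[99], [77]])      # one update value per example
positions = tf.constant([[2], [1]])      # where to place it in each row

batch_size = tf.shape(sequence)[0]
batch_idx = tf.broadcast_to(tf.range(batch_size)[:, tf.newaxis], tf.shape(positions))
indices = tf.reshape(tf.stack([batch_idx, positions], axis=-1), [-1, 2])
updated_sequence = tf.tensor_scatter_nd_update(
    sequence, indices, tf.reshape(updates, [-1]))
# updated_sequence == [[10, 11, 99, 13], [20, 77, 22, 23]]
```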
@@ -289,14 +291,14 @@ def sample_from_softmax(logits, disallow=None):
 """Implements softmax sampling using the Gumbel-softmax trick.
 Args:
-  logits: A [batch_size, num_token_predictions, vocab_size] tensor indicating
-    the generator output logits for each masked position.
+  logits: A `[batch_size, num_token_predictions, vocab_size]` tensor
+    indicating the generator output logits for each masked position.
   disallow: If `None`, we directly sample tokens from the logits. Otherwise,
-    this is a tensor of size [batch_size, num_token_predictions, vocab_size]
+    this is a tensor of size `[batch_size, num_token_predictions, vocab_size]`
     indicating the true word id in each masked position.
 Returns:
-  sampled_tokens: A [batch_size, num_token_predictions, vocab_size] one hot
+  sampled_tokens: A `[batch_size, num_token_predictions, vocab_size]` one hot
     tensor indicating the sampled word id in each masked position.
 """
 if disallow is not None:
...
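A minimal sketch of the Gumbel-max sampling idea referenced above: adding Gumbel noise to the logits and taking an argmax draws from the corresponding categorical distribution, returned here as a one-hot tensor to match the documented output shape. This is illustrative, not the function's exact implementation.

```python
import tensorflow as tf

def sample_one_hot_with_gumbel(logits):
  # argmax(logits + Gumbel noise) is a sample from Categorical(softmax(logits)).
  uniform = tf.random.uniform(tf.shape(logits), minval=1e-9, maxval=1.0)
  gumbel_noise = -tf.math.log(-tf.math.log(uniform))
  sampled_ids = tf.argmax(logits + gumbel_noise, axis=-1, output_type=tf.int32)
  return tf.one_hot(sampled_ids, depth=tf.shape(logits)[-1])

fake_logits = tf.random.normal([2, 3, 5])                  # (batch, num_token_predictions, vocab_size)
sampled_tokens = sample_one_hot_with_gumbel(fake_logits)   # one-hot, same shape
```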
@@ -135,18 +135,19 @@ class Seq2SeqTransformer(tf.keras.Model):
 Args:
   inputs: a dictionary of tensors.
-    Feature `inputs`: int tensor with shape [batch_size, input_length].
+    Feature `inputs`: int tensor with shape `[batch_size, input_length]`.
     Feature `targets` (optional): None or int tensor with shape
-      [batch_size, target_length].
+      `[batch_size, target_length]`.
 Returns:
   If targets is defined, then return logits for each word in the target
-  sequence. float tensor with shape [batch_size, target_length, vocab_size]
-  If target is none, then generate output sequence one token at a time.
+  sequence, which is a float tensor with shape
+  `(batch_size, target_length, vocab_size)`. If target is `None`, then
+  generate output sequence one token at a time and
   returns a dictionary {
-    outputs: [batch_size, decoded length]
-    scores: [batch_size, float]}
-  Even when float16 is used, the output tensor(s) are always float32.
+    outputs: `(batch_size, decoded_length)`
+    scores: `(batch_size, 1)`}
+  Even when `float16` is used, the output tensor(s) are always `float32`.
 Raises:
   NotImplementedError: If trying to use the padded decode method on CPU/GPUs.
@@ -288,15 +289,15 @@ class Seq2SeqTransformer(tf.keras.Model):
 """Generate logits for next potential IDs.
 Args:
-  ids: Current decoded sequences. int tensor with shape [batch_size *
-    beam_size, i + 1].
+  ids: Current decoded sequences. int tensor with shape
+    `(batch_size * beam_size, i + 1)`.
   i: Loop index.
-  cache: dictionary of values storing the encoder output, encoder-decoder
+  cache: Dictionary of values storing the encoder output, encoder-decoder
     attention bias, and previous decoder attention values.
 Returns:
   Tuple of
-    (logits with shape [batch_size * beam_size, vocab_size],
+    (logits with shape `(batch_size * beam_size, vocab_size)`,
      updated cache values)
 """
 # Set decoder input to the last generated IDs
@@ -440,13 +441,14 @@ class TransformerEncoder(tf.keras.layers.Layer):
 """Return the output of the encoder.
 Args:
-  encoder_inputs: tensor with shape [batch_size, input_length, hidden_size]
-  attention_mask: mask for the encoder self-attention layer. [batch_size,
-    input_length, input_length]
+  encoder_inputs: A tensor with shape
+    `(batch_size, input_length, hidden_size)`.
+  attention_mask: A mask for the encoder self-attention layer with shape
+    `(batch_size, input_length, input_length)`.
 Returns:
-  Output of encoder.
-  float32 tensor with shape [batch_size, input_length, hidden_size]
+  Output of encoder which is a `float32` tensor with shape
+  `(batch_size, input_length, hidden_size)`.
 """
 for layer_idx in range(self.num_layers):
   encoder_inputs = self.encoder_layers[layer_idx](
@@ -475,11 +477,11 @@ class TransformerDecoder(tf.keras.layers.Layer):
   activation: Activation for the intermediate layer.
   dropout_rate: Dropout probability.
   attention_dropout_rate: Dropout probability for attention layers.
-  use_bias: Whether to enable use_bias in attention layer. If set False,
+  use_bias: Whether to enable use_bias in attention layer. If set `False`,
     use_bias in attention layer is disabled.
   norm_first: Whether to normalize inputs to attention and intermediate dense
-    layers. If set False, output of attention and intermediate dense layers is
-    normalized.
+    layers. If set `False`, output of attention and intermediate dense layers
+    is normalized.
   norm_epsilon: Epsilon value to initialize normalization layers.
   intermediate_dropout: Dropout probability for intermediate_dropout_layer.
 """
@@ -555,23 +557,25 @@ class TransformerDecoder(tf.keras.layers.Layer):
 """Return the output of the decoder layer stacks.
 Args:
-  target: A tensor with shape [batch_size, target_length, hidden_size].
-  memory: A tensor with shape [batch_size, input_length, hidden_size]
-  memory_mask: A tensor with shape [batch_size, target_len, target_length],
-    the mask for decoder self-attention layer.
-  target_mask: A tensor with shape [batch_size, target_length, input_length]
-    which is the mask for encoder-decoder attention layer.
+  target: A tensor with shape `(batch_size, target_length, hidden_size)`.
+  memory: A tensor with shape `(batch_size, input_length, hidden_size)`.
+  memory_mask: A tensor with shape
+    `(batch_size, target_len, target_length)`, the mask for decoder
+    self-attention layer.
+  target_mask: A tensor with shape
+    `(batch_size, target_length, input_length)` which is the mask for
+    encoder-decoder attention layer.
   cache: (Used for fast decoding) A nested dictionary storing previous
     decoder self-attention values. The items are:
-    {layer_n: {"k": A tensor with shape [batch_size, i, key_channels],
-               "v": A tensor with shape [batch_size, i, value_channels]},
+    {layer_n: {"k": A tensor with shape `(batch_size, i, key_channels)`,
+               "v": A tensor with shape `(batch_size, i, value_channels)`},
     ...}
   decode_loop_step: An integer, the step number of the decoding loop. Used
     only for autoregressive inference on TPU.
 Returns:
   Output of decoder.
-  float32 tensor with shape [batch_size, target_length, hidden_size]
+  float32 tensor with shape `(batch_size, target_length, hidden_size)`.
 """
 output_tensor = target
...
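For concreteness, the decoding cache described above can be initialized with empty per-layer key/value tensors that grow by one step each decode iteration, roughly as below. The `layer_{n}` key naming and the sizes are assumptions for illustration.

```python
import tensorflow as tf

num_layers, batch_size, key_channels, value_channels = 2, 4, 64, 64
cache = {
    f'layer_{n}': {
        'k': tf.zeros([batch_size, 0, key_channels]),    # keys of previously decoded steps
        'v': tf.zeros([batch_size, 0, value_channels]),  # values of previously decoded steps
    }
    for n in range(num_layers)
}
```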
@@ -43,9 +43,9 @@ class AlbertEncoder(tf.keras.Model):
   vocab_size: The size of the token vocabulary.
   embedding_width: The width of the word embeddings. If the embedding width is
     not equal to hidden size, embedding parameters will be factorized into two
-    matrices in the shape of ['vocab_size', 'embedding_width'] and
-    ['embedding_width', 'hidden_size'] ('embedding_width' is usually much
-    smaller than 'hidden_size').
+    matrices in the shape of `(vocab_size, embedding_width)` and
+    `(embedding_width, hidden_size)`, where `embedding_width` is usually much
+    smaller than `hidden_size`.
   hidden_size: The size of the transformer hidden layers.
   num_layers: The number of transformer layers.
   num_attention_heads: The number of attention heads for each transformer. The
...
@@ -69,9 +69,9 @@ class BertEncoder(keras_nlp.encoders.BertEncoder):
     output.
   embedding_width: The width of the word embeddings. If the embedding width is
     not equal to hidden size, embedding parameters will be factorized into two
-    matrices in the shape of ['vocab_size', 'embedding_width'] and
-    ['embedding_width', 'hidden_size'] ('embedding_width' is usually much
-    smaller than 'hidden_size').
+    matrices in the shape of `(vocab_size, embedding_width)` and
+    `(embedding_width, hidden_size)`, where `embedding_width` is usually much
+    smaller than `hidden_size`.
   embedding_layer: The word embedding layer. `None` means we will create a new
     embedding layer. Otherwise, we will reuse the given embedding layer. This
     parameter is originally added for ELECTRA model which needs to tie the
...
@@ -35,8 +35,8 @@ class Classification(tf.keras.Model):
   activation: The activation, if any, for the dense layer in this network.
   initializer: The initializer for the dense layer in this network. Defaults
     to a Glorot uniform initializer.
-  output: The output style for this network. Can be either 'logits' or
-    'predictions'.
+  output: The output style for this network. Can be either `logits` or
+    `predictions`.
 """
 def __init__(self,
...
@@ -38,7 +38,7 @@ class EncoderScaffold(tf.keras.Model):
 class (which will replace the Transformer instantiation in the encoder). For
 each of these custom injection points, users can pass either a class or a
 class instance. If a class is passed, that class will be instantiated using
-the 'embedding_cfg' or 'hidden_cfg' argument, respectively; if an instance
+the `embedding_cfg` or `hidden_cfg` argument, respectively; if an instance
 is passed, that instance will be invoked. (In the case of hidden_cls, the
 instance will be invoked 'num_hidden_instances' times.)
@@ -53,40 +53,41 @@ class EncoderScaffold(tf.keras.Model):
   pooler_layer_initializer: The initializer for the classification layer.
   embedding_cls: The class or instance to use to embed the input data. This
     class or instance defines the inputs to this encoder and outputs (1)
-    embeddings tensor with shape [batch_size, seq_length, hidden_size] and (2)
-    attention masking with tensor [batch_size, seq_length, seq_length]. If
-    embedding_cls is not set, a default embedding network (from the original
-    BERT paper) will be created.
+    embeddings tensor with shape `(batch_size, seq_length, hidden_size)` and
+    (2) attention masking with tensor `(batch_size, seq_length, seq_length)`.
+    If `embedding_cls` is not set, a default embedding network (from the
+    original BERT paper) will be created.
   embedding_cfg: A dict of kwargs to pass to the embedding_cls, if it needs to
-    be instantiated. If embedding_cls is not set, a config dict must be
-    passed to 'embedding_cfg' with the following values:
-    "vocab_size": The size of the token vocabulary.
-    "type_vocab_size": The size of the type vocabulary.
-    "hidden_size": The hidden size for this encoder.
-    "max_seq_length": The maximum sequence length for this encoder.
-    "seq_length": The sequence length for this encoder.
-    "initializer": The initializer for the embedding portion of this encoder.
-    "dropout_rate": The dropout rate to apply before the encoding layers.
+    be instantiated. If `embedding_cls` is not set, a config dict must be
+    passed to `embedding_cfg` with the following values:
+    `vocab_size`: The size of the token vocabulary.
+    `type_vocab_size`: The size of the type vocabulary.
+    `hidden_size`: The hidden size for this encoder.
+    `max_seq_length`: The maximum sequence length for this encoder.
+    `seq_length`: The sequence length for this encoder.
+    `initializer`: The initializer for the embedding portion of this encoder.
+    `dropout_rate`: The dropout rate to apply before the encoding layers.
   embedding_data: A reference to the embedding weights that will be used to
     train the masked language model, if necessary. This is optional, and only
-    needed if (1) you are overriding embedding_cls and (2) are doing standard
-    pretraining.
+    needed if (1) you are overriding `embedding_cls` and (2) are doing
+    standard pretraining.
   num_hidden_instances: The number of times to instantiate and/or invoke the
     hidden_cls.
-  hidden_cls: The class or instance to encode the input data. If hidden_cls is
-    not set, a KerasBERT transformer layer will be used as the encoder class.
+  hidden_cls: The class or instance to encode the input data. If `hidden_cls`
+    is not set, a KerasBERT transformer layer will be used as the encoder
+    class.
   hidden_cfg: A dict of kwargs to pass to the hidden_cls, if it needs to be
     instantiated. If hidden_cls is not set, a config dict must be passed to
-    'hidden_cfg' with the following values:
-    "num_attention_heads": The number of attention heads. The hidden size
-      must be divisible by num_attention_heads.
-    "intermediate_size": The intermediate size of the transformer.
-    "intermediate_activation": The activation to apply in the transfomer.
-    "dropout_rate": The overall dropout rate for the transformer layers.
-    "attention_dropout_rate": The dropout rate for the attention layers.
-    "kernel_initializer": The initializer for the transformer layers.
+    `hidden_cfg` with the following values:
+    `num_attention_heads`: The number of attention heads. The hidden size
+      must be divisible by `num_attention_heads`.
+    `intermediate_size`: The intermediate size of the transformer.
+    `intermediate_activation`: The activation to apply in the transformer.
+    `dropout_rate`: The overall dropout rate for the transformer layers.
+    `attention_dropout_rate`: The dropout rate for the attention layers.
+    `kernel_initializer`: The initializer for the transformer layers.
   layer_norm_before_pooling: Whether to add a layer norm before the pooling
-    layer. You probably want to turn this on if you set norm_first=True in
+    layer. You probably want to turn this on if you set `norm_first=True` in
     transformer layers.
   return_all_layer_outputs: Whether to output sequence embedding outputs of
     all encoder transformer layers.
...
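To make the `embedding_cfg`/`hidden_cfg` contract above concrete, here is a hedged example using exactly the documented keys; the numeric values are illustrative BERT-base-like settings, not library defaults.

```python
import tensorflow as tf

embedding_cfg = {
    'vocab_size': 30522,
    'type_vocab_size': 2,
    'hidden_size': 768,
    'max_seq_length': 512,
    'seq_length': 128,
    'initializer': tf.keras.initializers.TruncatedNormal(stddev=0.02),
    'dropout_rate': 0.1,
}
hidden_cfg = {
    'num_attention_heads': 12,
    'intermediate_size': 3072,
    'intermediate_activation': 'gelu',
    'dropout_rate': 0.1,
    'attention_dropout_rate': 0.1,
    'kernel_initializer': tf.keras.initializers.TruncatedNormal(stddev=0.02),
}
# These would be passed as EncoderScaffold(embedding_cfg=embedding_cfg,
# hidden_cfg=hidden_cfg, num_hidden_instances=12, ...) per the argument list
# documented above.
```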
@@ -63,7 +63,7 @@ class MobileBERTEncoder(tf.keras.Model):
   attention_probs_dropout_prob: Dropout probability of the attention
     probabilities.
   intra_bottleneck_size: Size of bottleneck.
-  initializer_range: The stddev of the truncated_normal_initializer for
+  initializer_range: The stddev of the `truncated_normal_initializer` for
     initializing all weight matrices.
   use_bottleneck_attention: Use attention inputs from the bottleneck
     transformation. If true, the following `key_query_shared_bottleneck`
@@ -71,17 +71,17 @@ class MobileBERTEncoder(tf.keras.Model):
   key_query_shared_bottleneck: Whether to share linear transformation for
     keys and queries.
   num_feedforward_networks: Number of stacked feed-forward networks.
-  normalization_type: The type of normalization_type, only 'no_norm' and
-    'layer_norm' are supported. 'no_norm' represents the element-wise linear
+  normalization_type: The type of normalization_type, only `no_norm` and
+    `layer_norm` are supported. `no_norm` represents the element-wise linear
     transformation for the student model, as suggested by the original
-    MobileBERT paper. 'layer_norm' is used for the teacher model.
+    MobileBERT paper. `layer_norm` is used for the teacher model.
   classifier_activation: If using the tanh activation for the final
-    representation of the [CLS] token in fine-tuning.
+    representation of the `[CLS]` token in fine-tuning.
   input_mask_dtype: The dtype of `input_mask` tensor, which is one of the
     input tensors of this encoder. Defaults to `int32`. If you want
     to use `tf.lite` quantization, which does not support `Cast` op,
     please set this argument to `tf.float32` and feed `input_mask`
-    tensor with values in float32 to avoid `tf.cast` in the computation.
+    tensor with values in `float32` to avoid `tf.cast` in the computation.
   **kwargs: Other keyword arguments.
 """
 self._self_setattr_tracking = False
...