Commit 790e49e5 authored by stephenwu

Merge branch 'master' of https://github.com/tensorflow/models into run_superglue

parents 8ab018b0 5bb827c3
@@ -37,8 +37,8 @@ class DualEncoder(tf.keras.Model):
normalize: If set to True, normalize the encoding produced by the transformer.
logit_scale: The scaling factor of dot products when doing training.
logit_margin: The margin between positive and negative when doing training.
output: The output style for this network. Can be either `logits` or
`predictions`. If set to `predictions`, it will output the embedding
produced by the transformer network.
"""
......
@@ -52,8 +52,8 @@ class ElectraPretrainer(tf.keras.Model):
classification networks. If None, no activation will be used.
mlm_initializer: The initializer (if any) to use in the masked LM and
classification networks. Defaults to a Glorot uniform initializer.
output_type: The output style for this network. Can be either `logits` or
`predictions`.
disallow_correct: Whether to disallow the generator to generate the exact
same token in the original sentence.
"""
@@ -120,13 +120,13 @@ class ElectraPretrainer(tf.keras.Model):
Returns:
outputs: A dict of pretrainer model outputs, including
(1) lm_outputs: A `[batch_size, num_token_predictions, vocab_size]`
tensor indicating logits on masked positions.
(2) sentence_outputs: A `[batch_size, num_classes]` tensor indicating
logits for the NSP task.
(3) disc_logits: A `[batch_size, sequence_length]` tensor indicating
logits for the discriminator's replaced token detection task.
(4) disc_label: A `[batch_size, sequence_length]` tensor indicating
target labels for the discriminator's replaced token detection task.
"""
input_word_ids = inputs['input_word_ids']
@@ -176,7 +176,7 @@ class ElectraPretrainer(tf.keras.Model):
"""Generate corrupted data for the discriminator.
Args:
inputs: A dict of all inputs, the same as the input of the `call()` function.
mlm_logits: The generator's output logits.
duplicate: Whether to copy the original inputs dict during modifications.
@@ -227,16 +227,18 @@ def scatter_update(sequence, updates, positions):
"""Scatter-update a sequence.
Args:
sequence: A `[batch_size, seq_len]` or `[batch_size, seq_len, depth]` tensor.
updates: A tensor of size `batch_size*seq_len(*depth)`.
positions: A `[batch_size, n_positions]` tensor.
Returns:
updated_sequence: A `[batch_size, seq_len]` or `[batch_size, seq_len, depth]`
tensor of `sequence` with elements at `positions` replaced by the values at
`updates`. Updates to index 0 are ignored. If there are duplicated positions,
the update is only applied once.
updates_mask: A `[batch_size, seq_len]` mask tensor of which inputs were
updated.
"""
shape = tf_utils.get_shape_list(sequence, expected_rank=[2, 3])
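
For readers skimming the diff, a small, hedged illustration of the contract documented above; the values are hypothetical and not taken from the repository's tests:

```python
import tensorflow as tf

# Hypothetical inputs for scatter_update: a [batch_size=1, seq_len=5] sequence,
# two update values, and the positions where they should be written.
sequence = tf.constant([[10, 11, 12, 13, 14]])
updates = tf.constant([[99, 98]])
positions = tf.constant([[2, 4]])

# Under the documented contract, scatter_update(sequence, updates, positions)
# would be expected to return roughly:
#   updated_sequence == [[10, 11, 99, 13, 98]]
#   updates_mask     == [[0, 0, 1, 0, 1]]
```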
@@ -289,14 +291,14 @@ def sample_from_softmax(logits, disallow=None):
"""Implement softmax sampling using the Gumbel-softmax trick.
Args:
logits: A `[batch_size, num_token_predictions, vocab_size]` tensor
indicating the generator output logits for each masked position.
disallow: If `None`, we directly sample tokens from the logits. Otherwise,
this is a tensor of size `[batch_size, num_token_predictions, vocab_size]`
indicating the true word id in each masked position.
Returns:
sampled_tokens: A `[batch_size, num_token_predictions, vocab_size]` one-hot
tensor indicating the sampled word id in each masked position.
"""
if disallow is not None:
......
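
A minimal, hedged sketch of the Gumbel-trick sampling described above (illustrative only; the actual implementation in this file may differ in details such as how `disallow` is applied):

```python
import tensorflow as tf

def sample_from_softmax_sketch(logits, disallow=None):
  """Illustrative Gumbel-trick sampling; not the library implementation."""
  if disallow is not None:
    # Push the disallowed (true) token's logit far down so it is never drawn.
    logits -= 1000.0 * disallow
  uniform = tf.random.uniform(tf.shape(logits), minval=1e-9, maxval=1.0)
  gumbel_noise = -tf.math.log(-tf.math.log(uniform))
  sampled_ids = tf.argmax(logits + gumbel_noise, axis=-1, output_type=tf.int32)
  # Return a one-hot tensor over the vocabulary, matching the documented output.
  return tf.one_hot(sampled_ids, depth=tf.shape(logits)[-1])
```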
# Networks

Networks are combinations of `tf.keras` layers (and possibly other networks).
They are `tf.keras` models that would not be trained alone. They encapsulate
common network structures, such as a transformer encoder, into easily handled
objects with a standardized configuration.

* [`BertEncoder`](bert_encoder.py) implements a bi-directional
Transformer-based encoder as described in ["BERT: Pre-training of Deep
......
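
A hedged usage sketch of one of these networks follows; the configuration values are hypothetical, and the constructor arguments and output structure should be checked against the current `BertEncoder` code:

```python
import tensorflow as tf
from official.nlp.modeling import networks

# Hypothetical small configuration, for illustration only.
encoder = networks.BertEncoder(
    vocab_size=30522,
    hidden_size=256,
    num_layers=4,
    num_attention_heads=4,
    max_sequence_length=128)

word_ids = tf.ones((2, 128), dtype=tf.int32)
mask = tf.ones((2, 128), dtype=tf.int32)
type_ids = tf.zeros((2, 128), dtype=tf.int32)
# The encoder produces sequence-level and pooled outputs; the exact container
# (list vs. dict) depends on the library version.
outputs = encoder([word_ids, mask, type_ids])
```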
@@ -12,7 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Networks are combinations of `tf.keras` layers (and possibly other networks).

They are `tf.keras` models that would not be trained alone. They encapsulate
common network structures, such as a transformer encoder, into an easily
handled object with a standardized configuration.
"""
from official.nlp.modeling.networks.albert_encoder import AlbertEncoder
from official.nlp.modeling.networks.bert_encoder import BertEncoder
from official.nlp.modeling.networks.classification import Classification
......
@@ -43,9 +43,9 @@ class AlbertEncoder(tf.keras.Model):
vocab_size: The size of the token vocabulary.
embedding_width: The width of the word embeddings. If the embedding width is
not equal to the hidden size, embedding parameters will be factorized into two
matrices in the shape of `(vocab_size, embedding_width)` and
`(embedding_width, hidden_size)`, where `embedding_width` is usually much
smaller than `hidden_size`.
hidden_size: The size of the transformer hidden layers.
num_layers: The number of transformer layers.
num_attention_heads: The number of attention heads for each transformer. The
......
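
A brief, hedged sketch of the embedding factorization described above (sizes are hypothetical):

```python
import tensorflow as tf

# Hypothetical sizes chosen only to illustrate the factorization.
vocab_size, embedding_width, hidden_size = 30000, 128, 768

# Instead of a single (vocab_size, hidden_size) embedding table, the encoder
# learns two much smaller matrices.
word_embeddings = tf.Variable(tf.random.normal([vocab_size, embedding_width]))
projection = tf.Variable(tf.random.normal([embedding_width, hidden_size]))

token_ids = tf.constant([[7, 42, 3]])                          # (batch, seq_len)
embedded = tf.nn.embedding_lookup(word_embeddings, token_ids)  # (batch, seq_len, embedding_width)
hidden_inputs = tf.einsum('bse,eh->bsh', embedded, projection)  # (batch, seq_len, hidden_size)
```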
@@ -69,9 +69,9 @@ class BertEncoder(keras_nlp.encoders.BertEncoder):
output.
embedding_width: The width of the word embeddings. If the embedding width is
not equal to the hidden size, embedding parameters will be factorized into two
matrices in the shape of `(vocab_size, embedding_width)` and
`(embedding_width, hidden_size)`, where `embedding_width` is usually much
smaller than `hidden_size`.
embedding_layer: The word embedding layer. `None` means we will create a new
embedding layer. Otherwise, we will reuse the given embedding layer. This
parameter is originally added for the ELECTRA model, which needs to tie the
......
@@ -35,8 +35,8 @@ class Classification(tf.keras.Model):
activation: The activation, if any, for the dense layer in this network.
initializer: The initializer for the dense layer in this network. Defaults
to a Glorot uniform initializer.
output: The output style for this network. Can be either `logits` or
`predictions`.
"""
def __init__(self,
......
@@ -38,7 +38,7 @@ class EncoderScaffold(tf.keras.Model):
class (which will replace the Transformer instantiation in the encoder). For
each of these custom injection points, users can pass either a class or a
class instance. If a class is passed, that class will be instantiated using
the `embedding_cfg` or `hidden_cfg` argument, respectively; if an instance
is passed, that instance will be invoked. (In the case of `hidden_cls`, the
instance will be invoked `num_hidden_instances` times.)
@@ -53,40 +53,41 @@ class EncoderScaffold(tf.keras.Model):
pooler_layer_initializer: The initializer for the classification layer.
embedding_cls: The class or instance to use to embed the input data. This
class or instance defines the inputs to this encoder and outputs (1) an
embeddings tensor with shape `(batch_size, seq_length, hidden_size)` and
(2) attention masking with tensor `(batch_size, seq_length, seq_length)`.
If `embedding_cls` is not set, a default embedding network (from the
original BERT paper) will be created.
embedding_cfg: A dict of kwargs to pass to the `embedding_cls`, if it needs
to be instantiated. If `embedding_cls` is not set, a config dict must be
passed to `embedding_cfg` with the following values:
`vocab_size`: The size of the token vocabulary.
`type_vocab_size`: The size of the type vocabulary.
`hidden_size`: The hidden size for this encoder.
`max_seq_length`: The maximum sequence length for this encoder.
`seq_length`: The sequence length for this encoder.
`initializer`: The initializer for the embedding portion of this encoder.
`dropout_rate`: The dropout rate to apply before the encoding layers.
embedding_data: A reference to the embedding weights that will be used to
train the masked language model, if necessary. This is optional, and only
needed if (1) you are overriding `embedding_cls` and (2) are doing
standard pretraining.
num_hidden_instances: The number of times to instantiate and/or invoke the
`hidden_cls`.
hidden_cls: The class or instance to encode the input data. If `hidden_cls`
is not set, a KerasBERT transformer layer will be used as the encoder
class.
hidden_cfg: A dict of kwargs to pass to the `hidden_cls`, if it needs to be
instantiated. If `hidden_cls` is not set, a config dict must be passed to
`hidden_cfg` with the following values:
`num_attention_heads`: The number of attention heads. The hidden size
must be divisible by `num_attention_heads`.
`intermediate_size`: The intermediate size of the transformer.
`intermediate_activation`: The activation to apply in the transformer.
`dropout_rate`: The overall dropout rate for the transformer layers.
`attention_dropout_rate`: The dropout rate for the attention layers.
`kernel_initializer`: The initializer for the transformer layers.
layer_norm_before_pooling: Whether to add a layer norm before the pooling
layer. You probably want to turn this on if you set `norm_first=True` in
transformer layers.
return_all_layer_outputs: Whether to output sequence embedding outputs of
all encoder transformer layers.
......
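
A hedged configuration sketch for the two config dicts documented above. The dictionary keys come from the docstring; everything else (the sizes, the `pooled_output_dim` argument, and the activation object) is an assumption and should be checked against the current constructor:

```python
import tensorflow as tf
from official.nlp.modeling import networks

# Hypothetical sizes; the dict keys mirror the docstring above.
embedding_cfg = {
    'vocab_size': 30522,
    'type_vocab_size': 2,
    'hidden_size': 256,
    'max_seq_length': 128,
    'seq_length': 128,
    'initializer': tf.keras.initializers.TruncatedNormal(stddev=0.02),
    'dropout_rate': 0.1,
}
hidden_cfg = {
    'num_attention_heads': 4,
    'intermediate_size': 1024,
    'intermediate_activation': tf.nn.gelu,
    'dropout_rate': 0.1,
    'attention_dropout_rate': 0.1,
    'kernel_initializer': tf.keras.initializers.TruncatedNormal(stddev=0.02),
}
# `pooled_output_dim` is assumed here; verify the argument list before use.
encoder = networks.EncoderScaffold(
    num_hidden_instances=4,
    pooled_output_dim=256,
    embedding_cfg=embedding_cfg,
    hidden_cfg=hidden_cfg)
```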
@@ -63,7 +63,7 @@ class MobileBERTEncoder(tf.keras.Model):
attention_probs_dropout_prob: Dropout probability of the attention
probabilities.
intra_bottleneck_size: Size of bottleneck.
initializer_range: The stddev of the `truncated_normal_initializer` for
initializing all weight matrices.
use_bottleneck_attention: Use attention inputs from the bottleneck
transformation. If true, the following `key_query_shared_bottleneck`
@@ -71,17 +71,17 @@ class MobileBERTEncoder(tf.keras.Model):
key_query_shared_bottleneck: Whether to share linear transformation for
keys and queries.
num_feedforward_networks: Number of stacked feed-forward networks.
normalization_type: The type of normalization; only `no_norm` and
`layer_norm` are supported. `no_norm` represents the element-wise linear
transformation for the student model, as suggested by the original
MobileBERT paper. `layer_norm` is used for the teacher model.
classifier_activation: Whether to use the tanh activation for the final
representation of the `[CLS]` token in fine-tuning.
input_mask_dtype: The dtype of the `input_mask` tensor, which is one of the
input tensors of this encoder. Defaults to `int32`. If you want
to use `tf.lite` quantization, which does not support the `Cast` op,
set this argument to `tf.float32` and feed the `input_mask`
tensor with values in `float32` to avoid `tf.cast` in the computation.
**kwargs: Other keyword arguments.
"""
self._self_setattr_tracking = False
......
@@ -160,6 +160,8 @@ class MobileBertEncoderTest(parameterized.TestCase, tf.test.TestCase):
mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
prediction = classifier([word_ids, mask, type_ids])
if task == models.BertTokenClassifier:
  prediction = prediction['logits']
self.assertAllEqual(prediction.shape.as_list(), prediction_shape)
......
@@ -40,14 +40,14 @@ class PackedSequenceEmbedding(tf.keras.Model):
max_seq_length: The maximum sequence length for this encoder.
initializer: The initializer for the embedding portion of this encoder.
dropout_rate: The dropout rate to apply before the encoding layers.
pack_multiple_sequences: If `True`, we can feed multiple sequences into one
sequence for training and inference (they don't impact each other).
use_position_id: Whether to expect `position_ids` as an input to the
network. If `False`, the `position_ids` will be inferred: (1) when
`pack_multiple_sequences` is `False`, we assume the position ids are
`0, 1, 2, ..., seq_length - 1`; (2) when `pack_multiple_sequences` is
`True`, there may be multiple sub-sequences, and for each sub-sequence,
its position ids start from 0, 1, 2, ...
"""
def __init__(self,
......
@@ -37,8 +37,8 @@ class SpanLabeling(tf.keras.Model):
activation: The activation, if any, for the dense layer in this network.
initializer: The initializer for the dense layer in this network. Defaults
to a Glorot uniform initializer.
output: The output style for this network. Can be either `logits` or
`predictions`.
"""
def __init__(self,
@@ -228,20 +228,20 @@ class XLNetSpanLabeling(tf.keras.layers.Layer):
Args:
sequence_data: The input sequence data of shape
`(batch_size, seq_length, input_width)`.
class_index: The class indices of the inputs of shape `(batch_size,)`.
paragraph_mask: Invalid position mask such as query and special symbols
(e.g. PAD, SEP, CLS) of shape `(batch_size,)`.
start_positions: The start positions of each example of shape
`(batch_size,)`.
training: Whether or not this is the training phase.
Returns:
A dictionary with the keys `start_predictions`, `end_predictions`,
`start_logits`, `end_logits`.
At inference time, `start_top_predictions`, `start_top_index`,
`end_top_predictions`, `end_top_index` are also included.
"""
paragraph_mask = tf.cast(paragraph_mask, dtype=sequence_data.dtype)
......
@@ -20,6 +20,7 @@ from typing import Any, Callable, Dict, Tuple
import tensorflow as tf
from tensorflow.python.framework import dtypes
from official.modeling import tf_utils
Output = Tuple[tf.Tensor, tf.Tensor]
InternalState = Tuple[tf.Tensor, tf.Tensor, tf.Tensor, Dict]
@@ -64,15 +65,7 @@ def log_prob_from_logits(logits):
def shape_list(tensor):
  """Return a list of the tensor's shape, and ensure no None values in list."""
-  # Get statically known shape (may contain None's for unknown dimensions)
-  shape = tensor.get_shape().as_list()
-  # Ensure that the shape values are not None
-  dynamic_shape = tf.shape(tensor)
-  for i in range(len(shape)):  # pylint: disable=consider-using-enumerate
-    if shape[i] is None:
-      shape[i] = dynamic_shape[i]
-  return shape
+  return tf_utils.get_shape_list(tensor)

def get_shape_keep_last_dim(tensor):
......
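
For context, the removed inline logic above and `tf_utils.get_shape_list` both return a shape list in which statically unknown dimensions are replaced by their dynamic values. A rough, hedged sketch of that behavior:

```python
import tensorflow as tf

def get_shape_list_sketch(tensor):
  """Sketch of the static/dynamic shape mix described above (illustrative)."""
  static_shape = tensor.shape.as_list()   # may contain None for unknown dims
  dynamic_shape = tf.shape(tensor)        # always defined at runtime
  return [dynamic_shape[i] if dim is None else dim
          for i, dim in enumerate(static_shape)]
```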
@@ -76,14 +76,15 @@ def sample_top_p(logits, top_p):
  """
  sorted_indices = tf.argsort(logits, direction="DESCENDING")
  # Flatten logits as tf.gather on TPU needs axis to be compile time constant.
  logits_shape = decoding_module.shape_list(logits)
  range_for_gather = tf.expand_dims(tf.range(0, logits_shape[0]), axis=1)
  range_for_gather = tf.tile(range_for_gather * logits_shape[1],
                             [1, logits_shape[1]]) + sorted_indices
  flattened_logits = tf.reshape(logits, [-1])
  flattened_sorted_indices = tf.reshape(range_for_gather, [-1])
  sorted_logits = tf.reshape(
      tf.gather(flattened_logits, flattened_sorted_indices),
      [logits_shape[0], logits_shape[1]])
  cumulative_probs = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
  # Remove tokens with cumulative probability above the threshold.
......
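
For readers unfamiliar with nucleus (top-p) sampling, a compact, hedged sketch of the overall idea follows; it is illustrative only and does not use the TPU-friendly flattened gather shown above:

```python
import tensorflow as tf

def sample_top_p_sketch(logits, top_p):
  """Illustrative nucleus sampling over float32 [batch, vocab] logits."""
  sorted_logits = tf.sort(logits, direction="DESCENDING", axis=-1)
  sorted_indices = tf.argsort(logits, direction="DESCENDING", axis=-1)
  cumulative = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
  # Drop tokens once the cumulative probability exceeds top_p, but always keep
  # the most likely token by shifting the mask one position to the right.
  exceeded = cumulative > top_p
  exceeded = tf.concat(
      [tf.zeros_like(exceeded[:, :1]), exceeded[:, :-1]], axis=-1)
  filtered = tf.where(exceeded,
                      tf.fill(tf.shape(sorted_logits), -1e9),
                      sorted_logits)
  # Sample within the kept set, then map back to the original vocabulary ids.
  sampled_pos = tf.random.categorical(filtered, num_samples=1, dtype=tf.int32)
  return tf.gather(sorted_indices, sampled_pos, batch_dims=1)
```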
@@ -113,7 +113,7 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
correct way of using L2 regularization/weight decay with Adam, since that will
interact with the m and v parameters in strange ways.
Instead we want to decay the weights in a manner that doesn't interact with
the m/v parameters. This is equivalent to adding the square of the weights to
the loss with plain (non-momentum) SGD.
"""
@@ -171,7 +171,8 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
# and passed the allreduced grads_and_vars. For now, the
# clip_by_global_norm will be moved to before the explicit allreduce to
# keep the math the same as TF 1 and pre TF 2.2 implementation.
(grads, _) = tf.clip_by_global_norm(
    grads, clip_norm=self.gradient_clip_norm)
return super(AdamWeightDecay, self).apply_gradients(
    zip(grads, tvars),
    name=name,
......
@@ -98,13 +98,14 @@ class TaggingTask(base_task.Task):
initializer=tf.keras.initializers.TruncatedNormal(
    stddev=self.task_config.model.head_initializer_range),
dropout_rate=self.task_config.model.head_dropout,
output='logits',
output_encoder_outputs=True)

def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
  logits = tf.cast(model_outputs['logits'], tf.float32)
  masked_labels, masked_weights = _masked_labels_and_weights(labels)
  loss = tf.keras.losses.sparse_categorical_crossentropy(
      masked_labels, logits, from_logits=True)
  numerator_loss = tf.reduce_sum(loss * masked_weights)
  denominator_loss = tf.reduce_sum(masked_weights)
  loss = tf.math.divide_no_nan(numerator_loss, denominator_loss)
@@ -139,7 +140,7 @@ class TaggingTask(base_task.Task):
def inference_step(self, inputs, model: tf.keras.Model):
  """Performs the forward step."""
  logits = model(inputs, training=False)['logits']
  return {'logits': logits,
          'predict_ids': tf.argmax(logits, axis=-1, output_type=tf.int32)}
@@ -156,7 +157,7 @@ class TaggingTask(base_task.Task):
  """
  features, labels = inputs
  outputs = self.inference_step(features, model)
  loss = self.build_losses(labels=labels, model_outputs=outputs)
  # Negative label ids are padding labels which should be ignored.
  real_label_index = tf.where(tf.greater_equal(labels, 0))
......
@@ -302,7 +302,6 @@ class TranslationTask(base_task.Task):
    logs = {self.loss: loss}
    if metrics:
      self.process_metrics(metrics, inputs["targets"], outputs)
-     logs.update({m.name: m.result() for m in metrics})
    return logs

def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
......
@@ -51,7 +51,7 @@ class MockTask(base_task.Task):
def build_model(self, *arg, **kwargs):
  inputs = tf.keras.layers.Input(shape=(2,), name="random", dtype=tf.float32)
  outputs = tf.keras.layers.Dense(
      1, bias_initializer=tf.keras.initializers.Ones(), name="dense_0")(
          inputs)
  network = tf.keras.Model(inputs=inputs, outputs=outputs)
  return MockModel(network)
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.