Commit 790e49e5 authored by stephenwu

Merge branch 'master' of https://github.com/tensorflow/models into run_superglue

parents 8ab018b0 5bb827c3
@@ -37,8 +37,8 @@ class DualEncoder(tf.keras.Model):
    normalize: If set to True, normalize the encoding produced by the
      transformer.
    logit_scale: The scaling factor of dot products when doing training.
    logit_margin: The margin between positive and negative when doing training.
-    output: The output style for this network. Can be either 'logits' or
-      'predictions'. If set to 'predictions', it will output the embedding
-      producted by transformer network.
+    output: The output style for this network. Can be either `logits` or
+      `predictions`. If set to `predictions`, it will output the embedding
+      produced by the transformer network.
  """
...
@@ -52,8 +52,8 @@ class ElectraPretrainer(tf.keras.Model):
      classification networks. If None, no activation will be used.
    mlm_initializer: The initializer (if any) to use in the masked LM and
      classification networks. Defaults to a Glorot uniform initializer.
-    output_type: The output style for this network. Can be either 'logits' or
-      'predictions'.
+    output_type: The output style for this network. Can be either `logits` or
+      `predictions`.
    disallow_correct: Whether to disallow the generator from generating the
      exact same token as in the original sentence.
  """
...
@@ -120,13 +120,13 @@ class ElectraPretrainer(tf.keras.Model):
    Returns:
      outputs: A dict of pretrainer model outputs, including
-        (1) lm_outputs: a [batch_size, num_token_predictions, vocab_size] tensor
-          indicating logits on masked positions.
-        (2) sentence_outputs: a [batch_size, num_classes] tensor indicating
-          logits for nsp task.
-        (3) disc_logits: a [batch_size, sequence_length] tensor indicating
-          logits for discriminator replaced token detection task.
-        (4) disc_label: a [batch_size, sequence_length] tensor indicating
-          target labels for discriminator replaced token detection task.
+        (1) lm_outputs: A `[batch_size, num_token_predictions, vocab_size]`
+          tensor indicating logits on masked positions.
+        (2) sentence_outputs: A `[batch_size, num_classes]` tensor indicating
+          logits for the NSP task.
+        (3) disc_logits: A `[batch_size, sequence_length]` tensor indicating
+          logits for the discriminator's replaced-token-detection task.
+        (4) disc_label: A `[batch_size, sequence_length]` tensor indicating
+          target labels for the discriminator's replaced-token-detection task.
    """
    input_word_ids = inputs['input_word_ids']
...
@@ -176,7 +176,7 @@ class ElectraPretrainer(tf.keras.Model):
    """Generate corrupted data for discriminator.

    Args:
-      inputs: A dict of all inputs, same as the input of call() function
+      inputs: A dict of all inputs, same as the input of the `call()` function.
      mlm_logits: The generator's output logits.
      duplicate: Whether to copy the original inputs dict during modifications.
@@ -227,16 +227,18 @@ def scatter_update(sequence, updates, positions):
  """Scatter-update a sequence.

  Args:
-    sequence: A [batch_size, seq_len] or [batch_size, seq_len, depth] tensor
-    updates: A tensor of size batch_size*seq_len(*depth)
-    positions: A [batch_size, n_positions] tensor
+    sequence: A `[batch_size, seq_len]` or `[batch_size, seq_len, depth]`
+      tensor.
+    updates: A tensor of size `batch_size*seq_len(*depth)`.
+    positions: A `[batch_size, n_positions]` tensor.

  Returns:
-    updated_sequence: A [batch_size, seq_len] or [batch_size, seq_len, depth]
-      tensor of "sequence" with elements at "positions" replaced by the values
-      at "updates". Updates to index 0 are ignored. If there are duplicated
-      positions the update is only applied once.
-    updates_mask: A [batch_size, seq_len] mask tensor of which inputs were
-      updated.
+    updated_sequence: A `[batch_size, seq_len]` or
+      `[batch_size, seq_len, depth]` tensor of "sequence" with elements at
+      "positions" replaced by the values at "updates". Updates to index 0 are
+      ignored. If there are duplicated positions the update is only applied
+      once.
+    updates_mask: A `[batch_size, seq_len]` mask tensor of which inputs were
+      updated.
  """
  shape = tf_utils.get_shape_list(sequence, expected_rank=[2, 3])
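As an aside for reviewers, the semantics this docstring describes can be illustrated with a small, hypothetical sketch built on `tf.tensor_scatter_nd_update`; unlike the real `scatter_update`, this sketch does not skip updates at position 0 and does not return `updates_mask`:

```python
import tensorflow as tf

# Toy inputs: replace the tokens at `positions` in each row of `sequence`.
sequence = tf.constant([[5, 6, 7, 8],
                        [9, 10, 11, 12]])   # [batch_size, seq_len]
updates = tf.constant([[100, 200],
                       [300, 400]])         # [batch_size, n_positions]
positions = tf.constant([[1, 3],
                         [0, 2]])           # [batch_size, n_positions]

# Build (batch, position) index pairs for every update.
batch_ids = tf.broadcast_to(
    tf.range(tf.shape(sequence)[0])[:, None], tf.shape(positions))
indices = tf.reshape(tf.stack([batch_ids, positions], axis=-1), [-1, 2])
updated_sequence = tf.tensor_scatter_nd_update(
    sequence, indices, tf.reshape(updates, [-1]))
# updated_sequence -> [[5, 100, 7, 200],
#                      [300, 10, 400, 12]]
```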
@@ -289,14 +291,14 @@ def sample_from_softmax(logits, disallow=None):
  """Implement softmax sampling using gumbel softmax trick.

  Args:
-    logits: A [batch_size, num_token_predictions, vocab_size] tensor indicating
-      the generator output logits for each masked position.
+    logits: A `[batch_size, num_token_predictions, vocab_size]` tensor
+      indicating the generator output logits for each masked position.
    disallow: If `None`, we directly sample tokens from the logits. Otherwise,
-      this is a tensor of size [batch_size, num_token_predictions, vocab_size]
+      this is a tensor of size `[batch_size, num_token_predictions, vocab_size]`
      indicating the true word id in each masked position.

  Returns:
-    sampled_tokens: A [batch_size, num_token_predictions, vocab_size] one hot
+    sampled_tokens: A `[batch_size, num_token_predictions, vocab_size]` one hot
      tensor indicating the sampled word id in each masked position.
  """
  if disallow is not None:
...
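For reference, a minimal standalone sketch of the Gumbel-max trick the docstring refers to (this is not the module's implementation, just the underlying technique):

```python
import tensorflow as tf

def sample_with_gumbel_max(logits):
  """Samples token ids from softmax(logits) via the Gumbel-max trick.

  Adding i.i.d. Gumbel(0, 1) noise to the logits and taking the argmax is
  equivalent to drawing a sample from the categorical softmax distribution.
  """
  uniform = tf.random.uniform(
      tf.shape(logits), minval=1e-9, maxval=1.0, dtype=logits.dtype)
  gumbel_noise = -tf.math.log(-tf.math.log(uniform))
  sampled_ids = tf.argmax(logits + gumbel_noise, axis=-1, output_type=tf.int32)
  # The pretrainer returns one-hot vectors over the vocabulary, not raw ids.
  return tf.one_hot(sampled_ids, depth=tf.shape(logits)[-1], dtype=logits.dtype)
```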
@@ -23,10 +23,8 @@
from official.modeling import tf_utils
from official.nlp import keras_nlp
from official.nlp.modeling import layers
from official.nlp.modeling.ops import beam_search
-from official.nlp.transformer import model_utils

EOS_ID = 1

-# pylint: disable=g-classes-have-attributes

@tf.keras.utils.register_keras_serializable(package="Text")
@@ -52,7 +50,6 @@ class Seq2SeqTransformer(tf.keras.Model):
               alpha=0.6,
               encoder_layer=None,
               decoder_layer=None,
-               dtype=tf.float32,
               eos_id=EOS_ID,
               **kwargs):
    """Initialize layers to build Transformer model.
@@ -69,7 +66,6 @@ class Seq2SeqTransformer(tf.keras.Model):
      alpha: The strength of length normalization for beam search.
      encoder_layer: An initialized encoder layer.
      decoder_layer: An initialized decoder layer.
-      dtype: float dtype.
      eos_id: Id of the end-of-sentence token.
      **kwargs: other keyword arguments.
    """
@@ -82,7 +78,6 @@ class Seq2SeqTransformer(tf.keras.Model):
    self._extra_decode_length = extra_decode_length
    self._beam_size = beam_size
    self._alpha = alpha
-    self._dtype = dtype
    self._eos_id = eos_id
    self.embedding_lookup = keras_nlp.layers.OnDeviceEmbedding(
        vocab_size=self._vocab_size,
@@ -104,7 +99,6 @@ class Seq2SeqTransformer(tf.keras.Model):
        "dropout_rate": self._dropout_rate,
        "padded_decode": self._padded_decode,
        "decode_max_length": self._decode_max_length,
-        "dtype": self._dtype,
        "eos_id": self._eos_id,
        "extra_decode_length": self._extra_decode_length,
        "beam_size": self._beam_size,
@@ -123,10 +117,7 @@ class Seq2SeqTransformer(tf.keras.Model):
    vocab_size = tf.shape(embedding_matrix)[0]
    x = tf.reshape(x, [-1, hidden_size])
-    logits = tf.matmul(
-        tf.cast(x, dtype=self._dtype),
-        tf.cast(embedding_matrix, self._dtype),
-        transpose_b=True)
+    logits = tf.matmul(x, tf.cast(embedding_matrix, x.dtype), transpose_b=True)
    return tf.reshape(logits, [batch_size, length, vocab_size])
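The simplified matmul above keeps everything in the activation dtype instead of a stored `self._dtype`. A self-contained sketch of this weight-tied output projection (names here are illustrative, not the module's API):

```python
import tensorflow as tf

def embedding_linear(embedding_matrix, x):
  """Projects hidden states onto the vocabulary using the embedding table.

  x: float tensor of shape [batch_size, length, hidden_size].
  embedding_matrix: [vocab_size, hidden_size] table; it may be float32 while
    x is float16 under mixed precision, hence the cast to x.dtype.
  """
  batch_size, length, hidden_size = tf.unstack(tf.shape(x))
  x = tf.reshape(x, [-1, hidden_size])
  logits = tf.matmul(x, tf.cast(embedding_matrix, x.dtype), transpose_b=True)
  vocab_size = tf.shape(embedding_matrix)[0]
  return tf.reshape(logits, [batch_size, length, vocab_size])
```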
@@ -135,32 +126,29 @@ class Seq2SeqTransformer(tf.keras.Model):
    Args:
      inputs: a dictionary of tensors.
-        Feature `inputs`: int tensor with shape [batch_size, input_length].
-        Feature `targets` (optional): None or int tensor with shape
-          [batch_size, target_length].
+        Feature `inputs`: int tensor with shape `[batch_size, input_length]`.
+        Feature `targets` (optional): None or int tensor with shape
+          `[batch_size, target_length]`.

    Returns:
      If targets is defined, then return logits for each word in the target
-      sequence. float tensor with shape [batch_size, target_length, vocab_size]
-      If target is none, then generate output sequence one token at a time.
-        returns a dictionary {
-          outputs: [batch_size, decoded length]
-          scores: [batch_size, float]}
-      Even when float16 is used, the output tensor(s) are always float32.
+      sequence, which is a float tensor with shape
+      `(batch_size, target_length, vocab_size)`. If targets is `None`, then
+      generate the output sequence one token at a time and return a
+      dictionary {
+        outputs: `(batch_size, decoded_length)`
+        scores: `(batch_size, 1)`}
+      Even when `float16` is used, the output tensor(s) are always `float32`.

    Raises:
      NotImplementedError: If the padded decode method is used on CPU/GPUs.
    """
    sources = inputs["inputs"]
    targets = inputs.get("targets", None)
-    attention_bias = model_utils.get_padding_bias(sources)
-    attention_bias = tf.cast(attention_bias, self._dtype)

    # Prepare inputs to the layer stack by adding positional encodings and
    # applying dropout.
    embedded_inputs = self.embedding_lookup(sources)
-    embedding_mask = tf.cast(
-        tf.not_equal(sources, 0), self.embedding_lookup.embeddings.dtype)
-    embedded_inputs = tf.cast(embedded_inputs, self._dtype)
+    embedding_mask = tf.cast(tf.not_equal(sources, 0), embedded_inputs.dtype)
    embedded_inputs *= tf.expand_dims(embedding_mask, -1)
    # Attention_mask generation.
    input_shape = tf_utils.get_shape_list(sources, expected_rank=2)
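A tiny standalone illustration of the masking step above, with assumed toy values: token id 0 is padding, so its embedding rows are zeroed before positional encodings are added.

```python
import tensorflow as tf

sources = tf.constant([[7, 12, 0, 0]])                 # [batch, length]
embedded_inputs = tf.random.normal([1, 4, 8])          # [batch, length, hidden]
embedding_mask = tf.cast(tf.not_equal(sources, 0),
                         embedded_inputs.dtype)        # [[1., 1., 0., 0.]]
embedded_inputs *= tf.expand_dims(embedding_mask, -1)  # pad rows become zeros
```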
@@ -172,8 +160,8 @@ class Seq2SeqTransformer(tf.keras.Model):
        shape=[input_shape[0], input_shape[1], 1], dtype=sources.dtype)
    attention_mask = broadcast_ones * attention_mask

-    pos_encoding = self.position_embedding(inputs=embedded_inputs)
-    pos_encoding = tf.cast(pos_encoding, self._dtype)
+    pos_encoding = self.position_embedding(embedded_inputs)
+    pos_encoding = tf.cast(pos_encoding, embedded_inputs.dtype)
    encoder_inputs = embedded_inputs + pos_encoding
    encoder_inputs = self.encoder_dropout(encoder_inputs)
@@ -182,15 +170,11 @@ class Seq2SeqTransformer(tf.keras.Model):
        encoder_inputs, attention_mask=attention_mask)

    if targets is None:
-      encoder_decoder_attention_bias = attention_bias
-      encoder_outputs = tf.cast(encoder_outputs, self._dtype)
      if self._padded_decode:
        max_decode_length = self._decode_max_length
      else:
        max_decode_length = self._decode_max_length or (
            tf.shape(encoder_outputs)[1] + self._extra_decode_length)
-      encoder_decoder_attention_bias = tf.cast(encoder_decoder_attention_bias,
-                                               self._dtype)
      symbols_to_logits_fn = self._get_symbols_to_logits_fn(max_decode_length)

      batch_size = tf.shape(encoder_outputs)[0]
@@ -198,28 +182,35 @@ class Seq2SeqTransformer(tf.keras.Model):
      initial_ids = tf.zeros([batch_size], dtype=tf.int32)

      # Create cache storing decoder attention values for each layer.
-      # pylint: disable=g-complex-comprehension
      init_decode_length = (max_decode_length if self._padded_decode else 0)
      num_heads = self.decoder_layer.num_attention_heads
      dim_per_head = self._embedding_width // num_heads

+      # Cache dtype needs to match beam_search dtype.
+      # pylint: disable=g-complex-comprehension
      cache = {
          str(layer): {
              "key":
                  tf.zeros(
                      [batch_size, init_decode_length, num_heads, dim_per_head],
-                      dtype=self._dtype),
+                      dtype=self.compute_dtype),
              "value":
                  tf.zeros(
                      [batch_size, init_decode_length, num_heads, dim_per_head],
-                      dtype=self._dtype)
+                      dtype=self.compute_dtype)
          } for layer in range(self.decoder_layer.num_layers)
      }
      # pylint: enable=g-complex-comprehension

-      # Add encoder output and attention bias to the cache.
-      cache["encoder_outputs"] = encoder_outputs
-      cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias
+      # Add encoder output and attention mask to the cache.
+      encoder_outputs = tf.cast(encoder_outputs, dtype=self.compute_dtype)
+      attention_mask = tf.cast(
+          tf.reshape(
+              tf.not_equal(sources, 0), [input_shape[0], 1, input_shape[1]]),
+          dtype=self.compute_dtype)
+      cache["encoder_outputs"] = encoder_outputs
+      cache["encoder_decoder_attention_mask"] = attention_mask

      # Use beam search to find the top beam_size sequences and scores.
      decoded_ids, scores = beam_search.sequence_beam_search(
@@ -232,7 +223,7 @@ class Seq2SeqTransformer(tf.keras.Model):
          max_decode_length=max_decode_length,
          eos_id=self._eos_id,
          padded_decode=self._padded_decode,
-          dtype=self._dtype)
+          dtype=self.compute_dtype)

      # Get the top sequence for each batch element
      top_decoded_ids = decoded_ids[:, 0, 1:]
@@ -241,15 +232,13 @@ class Seq2SeqTransformer(tf.keras.Model):
      return {"outputs": top_decoded_ids, "scores": top_scores}

    decoder_inputs = self.embedding_lookup(targets)
-    embedding_mask = tf.cast(
-        tf.not_equal(targets, 0), self.embedding_lookup.embeddings.dtype)
-    decoder_inputs = tf.cast(decoder_inputs, self._dtype)
+    embedding_mask = tf.cast(tf.not_equal(targets, 0), decoder_inputs.dtype)
    decoder_inputs *= tf.expand_dims(embedding_mask, -1)
    # Shift targets to the right, and remove the last element
    decoder_inputs = tf.pad(decoder_inputs, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
    length = tf.shape(decoder_inputs)[1]
    pos_encoding = self.position_embedding(decoder_inputs)
-    pos_encoding = tf.cast(pos_encoding, self._dtype)
+    pos_encoding = tf.cast(pos_encoding, embedded_inputs.dtype)
    decoder_inputs += pos_encoding
    decoder_inputs = self.decoder_dropout(decoder_inputs)
@@ -258,8 +247,7 @@ class Seq2SeqTransformer(tf.keras.Model):
    batch_size = decoder_shape[0]
    decoder_length = decoder_shape[1]

-    self_attention_mask = tf.linalg.band_part(
-        tf.ones([length, length], dtype=tf.float32), -1, 0)
+    self_attention_mask = tf.linalg.band_part(tf.ones([length, length]), -1, 0)
    self_attention_mask = tf.reshape(self_attention_mask, [1, length, length])
    self_attention_mask = tf.tile(self_attention_mask, [batch_size, 1, 1])
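For reference, `tf.linalg.band_part(ones, -1, 0)` keeps the main diagonal plus everything below it, which is exactly the causal mask built here; for `length = 3`:

```python
import tensorflow as tf

causal_mask = tf.linalg.band_part(tf.ones([3, 3]), -1, 0)
# [[1., 0., 0.],
#  [1., 1., 0.],
#  [1., 1., 1.]]
# Row i may attend only to positions 0..i.
```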
@@ -273,6 +261,8 @@ class Seq2SeqTransformer(tf.keras.Model):
        memory_mask=self_attention_mask,
        target_mask=attention_mask)
    logits = self._embedding_linear(self.embedding_lookup.embeddings, outputs)
+    # Model outputs should be float32 to avoid numeric issues.
+    # https://www.tensorflow.org/guide/mixed_precision#building_the_model
    logits = tf.cast(logits, tf.float32)
    return logits
@@ -280,23 +270,26 @@ class Seq2SeqTransformer(tf.keras.Model):
    """Returns a decoding function that calculates logits of the next tokens."""
    timing_signal = self.position_embedding(
        inputs=None, length=max_decode_length + 1)
-    timing_signal = tf.cast(timing_signal, self._dtype)
-    decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
-        max_decode_length, dtype=self._dtype)
+    timing_signal = tf.cast(timing_signal, dtype=self.compute_dtype)
+    decoder_self_attention_mask = tf.linalg.band_part(
+        tf.ones([max_decode_length, max_decode_length],
+                dtype=self.compute_dtype), -1, 0)
+    decoder_self_attention_mask = tf.reshape(
+        decoder_self_attention_mask, [1, max_decode_length, max_decode_length])

    def symbols_to_logits_fn(ids, i, cache):
      """Generate logits for next potential IDs.

      Args:
-        ids: Current decoded sequences. int tensor with shape [batch_size *
-          beam_size, i + 1].
+        ids: Current decoded sequences. int tensor with shape
+          `(batch_size * beam_size, i + 1)`.
        i: Loop index.
-        cache: dictionary of values storing the encoder output, encoder-decoder
-          attention bias, and previous decoder attention values.
+        cache: Dictionary of values storing the encoder output, encoder-decoder
+          attention mask, and previous decoder attention values.

      Returns:
        Tuple of
-          (logits with shape [batch_size * beam_size, vocab_size],
+          (logits with shape `(batch_size * beam_size, vocab_size)`,
           updated cache values)
      """
      # Set decoder input to the last generated IDs
@@ -307,33 +300,24 @@ class Seq2SeqTransformer(tf.keras.Model):
      source_decoder_input = decoder_input
      decoder_input = self.embedding_lookup(decoder_input)
      embedding_mask = tf.cast(
-          tf.not_equal(source_decoder_input, 0),
-          self.embedding_lookup.embeddings.dtype)
+          tf.not_equal(source_decoder_input, 0), decoder_input.dtype)
      decoder_input *= tf.expand_dims(embedding_mask, -1)
      decoder_input += timing_signal[i]

      if self._padded_decode:
-        bias_shape = decoder_self_attention_bias.shape.as_list()
-        self_attention_bias = tf.slice(
-            decoder_self_attention_bias, [0, 0, i, 0],
-            [bias_shape[0], bias_shape[1], 1, bias_shape[3]])
+        # Indexing does not work on TPU.
+        bias_shape = decoder_self_attention_mask.shape.as_list()
+        self_attention_mask = tf.slice(
+            decoder_self_attention_mask, [0, i, 0],
+            [bias_shape[0], 1, bias_shape[2]])
      else:
-        self_attention_bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]
+        self_attention_mask = decoder_self_attention_mask[:, i:i + 1, :i + 1]

      decoder_shape = tf_utils.get_shape_list(decoder_input, expected_rank=3)
      batch_size = decoder_shape[0]
      decoder_length = decoder_shape[1]

-      attention_bias = cache.get("encoder_decoder_attention_bias")
-      attention_bias = tf.where(attention_bias < 0,
-                                tf.zeros_like(attention_bias),
-                                tf.ones_like(attention_bias))
-      attention_bias = tf.squeeze(attention_bias, axis=[1])
-      attention_mask = tf.tile(attention_bias, [1, decoder_length, 1])
-
-      self_attention_bias = tf.where(self_attention_bias < 0,
-                                     tf.zeros_like(self_attention_bias),
-                                     tf.ones_like(self_attention_bias))
-      self_attention_bias = tf.squeeze(self_attention_bias, axis=[1])
-      self_attention_mask = tf.tile(self_attention_bias, [batch_size, 1, 1])
+      self_attention_mask = tf.tile(self_attention_mask, [batch_size, 1, 1])
+      attention_mask = cache.get("encoder_decoder_attention_mask")
+      attention_mask = tf.tile(attention_mask, [1, decoder_length, 1])

      decoder_outputs = self.decoder_layer(
          decoder_input,
@@ -343,6 +327,7 @@ class Seq2SeqTransformer(tf.keras.Model):
          cache=cache,
          decode_loop_step=i if self._padded_decode else None)

+      decoder_outputs = tf.cast(decoder_outputs, dtype=self.compute_dtype)
      logits = self._embedding_linear(self.embedding_lookup.embeddings,
                                      decoder_outputs)
      logits = tf.squeeze(logits, axis=[1])
@@ -358,21 +343,6 @@ class TransformerEncoder(tf.keras.layers.Layer):
  of the sublayers:
    1. Self-attention layer
    2. Feedforward network (which is 2 fully-connected layers)
-
-  Args:
-    num_layers: Number of layers.
-    num_attention_heads: Number of attention heads.
-    intermediate_size: Size of the intermediate (Feedforward) layer.
-    activation: Activation for the intermediate layer.
-    dropout_rate: Dropout probability.
-    attention_dropout_rate: Dropout probability for attention layers.
-    use_bias: Whether to enable use_bias in attention layer. If set False,
-      use_bias in attention layer is disabled.
-    norm_first: Whether to normalize inputs to attention and intermediate dense
-      layers. If set False, output of attention and intermediate dense layers is
-      normalized.
-    norm_epsilon: Epsilon value to initialize normalization layers.
-    intermediate_dropout: Dropout probability for intermediate_dropout_layer.
  """

  def __init__(self,
@@ -387,6 +357,25 @@ class TransformerEncoder(tf.keras.layers.Layer):
               norm_epsilon=1e-6,
               intermediate_dropout=0.0,
               **kwargs):
+    """Initialize a Transformer encoder.
+
+    Args:
+      num_layers: Number of layers.
+      num_attention_heads: Number of attention heads.
+      intermediate_size: Size of the intermediate (Feedforward) layer.
+      activation: Activation for the intermediate layer.
+      dropout_rate: Dropout probability.
+      attention_dropout_rate: Dropout probability for attention layers.
+      use_bias: Whether to enable use_bias in attention layer. If set False,
+        use_bias in attention layer is disabled.
+      norm_first: Whether to normalize inputs to attention and intermediate
+        dense layers. If set False, output of attention and intermediate
+        dense layers is normalized.
+      norm_epsilon: Epsilon value to initialize normalization layers.
+      intermediate_dropout: Dropout probability for intermediate_dropout_layer.
+      **kwargs: keyword arguments passed to tf.keras.layers.Layer.
+    """
    super(TransformerEncoder, self).__init__(**kwargs)
    self.num_layers = num_layers
    self.num_attention_heads = num_attention_heads
@@ -440,13 +429,14 @@ class TransformerEncoder(tf.keras.layers.Layer):
    """Return the output of the encoder.

    Args:
-      encoder_inputs: tensor with shape [batch_size, input_length, hidden_size]
-      attention_mask: mask for the encoder self-attention layer. [batch_size,
-        input_length, input_length]
+      encoder_inputs: A tensor with shape
+        `(batch_size, input_length, hidden_size)`.
+      attention_mask: A mask for the encoder self-attention layer with shape
+        `(batch_size, input_length, input_length)`.

    Returns:
-      Output of encoder.
-      float32 tensor with shape [batch_size, input_length, hidden_size]
+      Output of encoder, which is a `float32` tensor with shape
+      `(batch_size, input_length, hidden_size)`.
    """
    for layer_idx in range(self.num_layers):
      encoder_inputs = self.encoder_layers[layer_idx](
@@ -467,21 +457,6 @@ class TransformerDecoder(tf.keras.layers.Layer):
    2. Multi-headed attention layer combining encoder outputs with results from
       the previous self-attention layer.
    3. Feedforward network (2 fully-connected layers)
-
-  Args:
-    num_layers: Number of layers.
-    num_attention_heads: Number of attention heads.
-    intermediate_size: Size of the intermediate (Feedforward) layer.
-    activation: Activation for the intermediate layer.
-    dropout_rate: Dropout probability.
-    attention_dropout_rate: Dropout probability for attention layers.
-    use_bias: Whether to enable use_bias in attention layer. If set False,
-      use_bias in attention layer is disabled.
-    norm_first: Whether to normalize inputs to attention and intermediate dense
-      layers. If set False, output of attention and intermediate dense layers is
-      normalized.
-    norm_epsilon: Epsilon value to initialize normalization layers.
-    intermediate_dropout: Dropout probability for intermediate_dropout_layer.
  """

  def __init__(self,
@@ -496,6 +471,24 @@ class TransformerDecoder(tf.keras.layers.Layer):
               norm_epsilon=1e-6,
               intermediate_dropout=0.0,
               **kwargs):
+    """Initialize a Transformer decoder.
+
+    Args:
+      num_layers: Number of layers.
+      num_attention_heads: Number of attention heads.
+      intermediate_size: Size of the intermediate (Feedforward) layer.
+      activation: Activation for the intermediate layer.
+      dropout_rate: Dropout probability.
+      attention_dropout_rate: Dropout probability for attention layers.
+      use_bias: Whether to enable use_bias in attention layer. If set `False`,
+        use_bias in attention layer is disabled.
+      norm_first: Whether to normalize inputs to attention and intermediate
+        dense layers. If set `False`, output of attention and intermediate
+        dense layers is normalized.
+      norm_epsilon: Epsilon value to initialize normalization layers.
+      intermediate_dropout: Dropout probability for intermediate_dropout_layer.
+      **kwargs: keyword arguments passed to tf.keras.layers.Layer.
+    """
    super(TransformerDecoder, self).__init__(**kwargs)
    self.num_layers = num_layers
    self.num_attention_heads = num_attention_heads
@@ -555,23 +548,25 @@ class TransformerDecoder(tf.keras.layers.Layer):
    """Return the output of the decoder layer stacks.

    Args:
-      target: A tensor with shape [batch_size, target_length, hidden_size].
-      memory: A tensor with shape [batch_size, input_length, hidden_size]
-      memory_mask: A tensor with shape [batch_size, target_len, target_length],
-        the mask for decoder self-attention layer.
-      target_mask: A tensor with shape [batch_size, target_length, input_length]
-        which is the mask for encoder-decoder attention layer.
+      target: A tensor with shape `(batch_size, target_length, hidden_size)`.
+      memory: A tensor with shape `(batch_size, input_length, hidden_size)`.
+      memory_mask: A tensor with shape
+        `(batch_size, target_length, target_length)`, the mask for the decoder
+        self-attention layer.
+      target_mask: A tensor with shape
+        `(batch_size, target_length, input_length)`, which is the mask for the
+        encoder-decoder attention layer.
      cache: (Used for fast decoding) A nested dictionary storing previous
        decoder self-attention values. The items are:
-          {layer_n: {"k": A tensor with shape [batch_size, i, key_channels],
-                     "v": A tensor with shape [batch_size, i, value_channels]},
+          {layer_n: {"k": A tensor with shape `(batch_size, i, key_channels)`,
+                     "v": A tensor with shape `(batch_size, i, value_channels)`},
                       ...}
      decode_loop_step: An integer, the step number of the decoding loop. Used
        only for autoregressive inference on TPU.

    Returns:
-      Output of decoder.
-      float32 tensor with shape [batch_size, target_length, hidden_size]
+      Output of decoder, a `float32` tensor with shape
+      `(batch_size, target_length, hidden_size)`.
    """
    output_tensor = target
...
# Networks

-Networks are combinations of layers (and possibly other networks). They are
-sub-units of models that would not be trained alone. It encapsulates common
-network structures like a classification head or a transformer encoder into an
-easily handled object with a standardized configuration.
+Networks are combinations of `tf.keras` layers (and possibly other networks).
+They are `tf.keras` models that would not be trained alone. They encapsulate
+common network structures like a transformer encoder into an easily handled
+object with a standardized configuration.

* [`BertEncoder`](bert_encoder.py) implements a bi-directional
  Transformer-based encoder as described in ["BERT: Pre-training of Deep
...
@@ -12,7 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Networks package definition."""
+"""Networks are combinations of `tf.keras` layers (and possibly other networks).
+
+They are `tf.keras` models that would not be trained alone. They encapsulate
+common network structures like a transformer encoder into an easily handled
+object with a standardized configuration.
+"""
from official.nlp.modeling.networks.albert_encoder import AlbertEncoder
from official.nlp.modeling.networks.bert_encoder import BertEncoder
from official.nlp.modeling.networks.classification import Classification
...
@@ -43,9 +43,9 @@ class AlbertEncoder(tf.keras.Model):
    vocab_size: The size of the token vocabulary.
-    embedding_width: The width of the word embeddings. If the embedding width is
-      not equal to hidden size, embedding parameters will be factorized into two
-      matrices in the shape of ['vocab_size', 'embedding_width'] and
-      ['embedding_width', 'hidden_size'] ('embedding_width' is usually much
-      smaller than 'hidden_size').
+    embedding_width: The width of the word embeddings. If the embedding width
+      is not equal to hidden size, embedding parameters will be factorized into
+      two matrices in the shape of `(vocab_size, embedding_width)` and
+      `(embedding_width, hidden_size)`, where `embedding_width` is usually much
+      smaller than `hidden_size`.
    hidden_size: The size of the transformer hidden layers.
    num_layers: The number of transformer layers.
    num_attention_heads: The number of attention heads for each transformer. The
...
@@ -69,9 +69,9 @@ class BertEncoder(keras_nlp.encoders.BertEncoder):
      output.
-    embedding_width: The width of the word embeddings. If the embedding width is
-      not equal to hidden size, embedding parameters will be factorized into two
-      matrices in the shape of ['vocab_size', 'embedding_width'] and
-      ['embedding_width', 'hidden_size'] ('embedding_width' is usually much
-      smaller than 'hidden_size').
+    embedding_width: The width of the word embeddings. If the embedding width
+      is not equal to hidden size, embedding parameters will be factorized into
+      two matrices in the shape of `(vocab_size, embedding_width)` and
+      `(embedding_width, hidden_size)`, where `embedding_width` is usually much
+      smaller than `hidden_size`.
    embedding_layer: The word embedding layer. `None` means we will create a new
      embedding layer. Otherwise, we will reuse the given embedding layer. This
      parameter is originally added for ELECTRA model which needs to tie the
...
@@ -35,8 +35,8 @@ class Classification(tf.keras.Model):
    activation: The activation, if any, for the dense layer in this network.
    initializer: The initializer for the dense layer in this network. Defaults
      to a Glorot uniform initializer.
-    output: The output style for this network. Can be either 'logits' or
-      'predictions'.
+    output: The output style for this network. Can be either `logits` or
+      `predictions`.
  """

  def __init__(self,
...
@@ -38,7 +38,7 @@ class EncoderScaffold(tf.keras.Model):
  class (which will replace the Transformer instantiation in the encoder). For
  each of these custom injection points, users can pass either a class or a
  class instance. If a class is passed, that class will be instantiated using
-  the 'embedding_cfg' or 'hidden_cfg' argument, respectively; if an instance
+  the `embedding_cfg` or `hidden_cfg` argument, respectively; if an instance
  is passed, that instance will be invoked. (In the case of `hidden_cls`, the
  instance will be invoked `num_hidden_instances` times.)
@@ -53,40 +53,41 @@ class EncoderScaffold(tf.keras.Model):
    pooler_layer_initializer: The initializer for the classification layer.
    embedding_cls: The class or instance to use to embed the input data. This
      class or instance defines the inputs to this encoder and outputs (1)
-      embeddings tensor with shape [batch_size, seq_length, hidden_size] and (2)
-      attention masking with tensor [batch_size, seq_length, seq_length]. If
-      embedding_cls is not set, a default embedding network (from the original
-      BERT paper) will be created.
+      embeddings tensor with shape `(batch_size, seq_length, hidden_size)` and
+      (2) attention masking with tensor `(batch_size, seq_length, seq_length)`.
+      If `embedding_cls` is not set, a default embedding network (from the
+      original BERT paper) will be created.
    embedding_cfg: A dict of kwargs to pass to the embedding_cls, if it needs to
-      be instantiated. If embedding_cls is not set, a config dict must be
-      passed to 'embedding_cfg' with the following values:
-      "vocab_size": The size of the token vocabulary.
-      "type_vocab_size": The size of the type vocabulary.
-      "hidden_size": The hidden size for this encoder.
-      "max_seq_length": The maximum sequence length for this encoder.
-      "seq_length": The sequence length for this encoder.
-      "initializer": The initializer for the embedding portion of this encoder.
-      "dropout_rate": The dropout rate to apply before the encoding layers.
+      be instantiated. If `embedding_cls` is not set, a config dict must be
+      passed to `embedding_cfg` with the following values:
+      `vocab_size`: The size of the token vocabulary.
+      `type_vocab_size`: The size of the type vocabulary.
+      `hidden_size`: The hidden size for this encoder.
+      `max_seq_length`: The maximum sequence length for this encoder.
+      `seq_length`: The sequence length for this encoder.
+      `initializer`: The initializer for the embedding portion of this encoder.
+      `dropout_rate`: The dropout rate to apply before the encoding layers.
    embedding_data: A reference to the embedding weights that will be used to
      train the masked language model, if necessary. This is optional, and only
-      needed if (1) you are overriding embedding_cls and (2) are doing standard
-      pretraining.
+      needed if (1) you are overriding `embedding_cls` and (2) are doing
+      standard pretraining.
    num_hidden_instances: The number of times to instantiate and/or invoke the
      hidden_cls.
-    hidden_cls: The class or instance to encode the input data. If hidden_cls is
-      not set, a KerasBERT transformer layer will be used as the encoder class.
+    hidden_cls: The class or instance to encode the input data. If `hidden_cls`
+      is not set, a KerasBERT transformer layer will be used as the encoder
+      class.
    hidden_cfg: A dict of kwargs to pass to the hidden_cls, if it needs to be
      instantiated. If hidden_cls is not set, a config dict must be passed to
-      'hidden_cfg' with the following values:
-      "num_attention_heads": The number of attention heads. The hidden size
-        must be divisible by num_attention_heads.
-      "intermediate_size": The intermediate size of the transformer.
-      "intermediate_activation": The activation to apply in the transfomer.
-      "dropout_rate": The overall dropout rate for the transformer layers.
-      "attention_dropout_rate": The dropout rate for the attention layers.
-      "kernel_initializer": The initializer for the transformer layers.
+      `hidden_cfg` with the following values:
+      `num_attention_heads`: The number of attention heads. The hidden size
+        must be divisible by `num_attention_heads`.
+      `intermediate_size`: The intermediate size of the transformer.
+      `intermediate_activation`: The activation to apply in the transformer.
+      `dropout_rate`: The overall dropout rate for the transformer layers.
+      `attention_dropout_rate`: The dropout rate for the attention layers.
+      `kernel_initializer`: The initializer for the transformer layers.
    layer_norm_before_pooling: Whether to add a layer norm before the pooling
-      layer. You probably want to turn this on if you set norm_first=True in
+      layer. You probably want to turn this on if you set `norm_first=True` in
      transformer layers.
    return_all_layer_outputs: Whether to output sequence embedding outputs of
      all encoder transformer layers.
...
@@ -63,7 +63,7 @@ class MobileBERTEncoder(tf.keras.Model):
    attention_probs_dropout_prob: Dropout probability of the attention
      probabilities.
    intra_bottleneck_size: Size of bottleneck.
-    initializer_range: The stddev of the truncated_normal_initializer for
+    initializer_range: The stddev of the `truncated_normal_initializer` for
      initializing all weight matrices.
    use_bottleneck_attention: Use attention inputs from the bottleneck
      transformation. If true, the following `key_query_shared_bottleneck`
@@ -71,17 +71,17 @@ class MobileBERTEncoder(tf.keras.Model):
    key_query_shared_bottleneck: Whether to share linear transformation for
      keys and queries.
    num_feedforward_networks: Number of stacked feed-forward networks.
-    normalization_type: The type of normalization_type, only 'no_norm' and
-      'layer_norm' are supported. 'no_norm' represents the element-wise linear
-      transformation for the student model, as suggested by the original
-      MobileBERT paper. 'layer_norm' is used for the teacher model.
+    normalization_type: The type of normalization, only `no_norm` and
+      `layer_norm` are supported. `no_norm` represents the element-wise linear
+      transformation for the student model, as suggested by the original
+      MobileBERT paper. `layer_norm` is used for the teacher model.
    classifier_activation: If using the tanh activation for the final
-      representation of the [CLS] token in fine-tuning.
+      representation of the `[CLS]` token in fine-tuning.
    input_mask_dtype: The dtype of `input_mask` tensor, which is one of the
      input tensors of this encoder. Defaults to `int32`. If you want
      to use `tf.lite` quantization, which does not support `Cast` op,
      please set this argument to `tf.float32` and feed `input_mask`
-      tensor with values in float32 to avoid `tf.cast` in the computation.
+      tensor with values in `float32` to avoid `tf.cast` in the computation.
    **kwargs: Other keyword arguments.
  """
  self._self_setattr_tracking = False
...
@@ -160,6 +160,8 @@ class MobileBertEncoderTest(parameterized.TestCase, tf.test.TestCase):
    mask = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    type_ids = tf.keras.Input(shape=(sequence_length,), dtype=tf.int32)
    prediction = classifier([word_ids, mask, type_ids])
+    if task == models.BertTokenClassifier:
+      prediction = prediction['logits']
    self.assertAllEqual(prediction.shape.as_list(), prediction_shape)
...
@@ -40,14 +40,14 @@ class PackedSequenceEmbedding(tf.keras.Model):
    max_seq_length: The maximum sequence length for this encoder.
    initializer: The initializer for the embedding portion of this encoder.
    dropout_rate: The dropout rate to apply before the encoding layers.
-    pack_multiple_sequences: If True, we can feed multiple sequences into one
+    pack_multiple_sequences: If `True`, we can feed multiple sequences into one
      sequence for training and inference (they don't impact each other).
    use_position_id: Whether to expect `position_ids` as an input to the
-      network. If False, the `position_ids` will be inferred: (1) when
-      pack_multiple_sequences is False, we assume the position ids are 0, 1,
-      2, ..., seq_length - 1; (2) when pack_multiple_sequences is True, there
-      may be multiple sub sequences, and for each sub sequence, its position
-      ids start from 0, 1, 2, ...
+      network. If `False`, the `position_ids` will be inferred: (1) when
+      `pack_multiple_sequences` is `False`, we assume the position ids are
+      `0, 1, 2, ..., seq_length - 1`; (2) when `pack_multiple_sequences` is
+      `True`, there may be multiple sub sequences, and for each sub sequence,
+      its position ids start from 0, 1, 2, ...
  """

  def __init__(self,
...
@@ -37,8 +37,8 @@ class SpanLabeling(tf.keras.Model):
    activation: The activation, if any, for the dense layer in this network.
    initializer: The initializer for the dense layer in this network. Defaults
      to a Glorot uniform initializer.
-    output: The output style for this network. Can be either 'logits' or
-      'predictions'.
+    output: The output style for this network. Can be either `logits` or
+      `predictions`.
  """

  def __init__(self,
@@ -228,20 +228,20 @@ class XLNetSpanLabeling(tf.keras.layers.Layer):
    Args:
      sequence_data: The input sequence data of shape
-        (batch_size, seq_length, input_width).
-      class_index: The class indices of the inputs of shape (batch_size,).
+        `(batch_size, seq_length, input_width)`.
+      class_index: The class indices of the inputs of shape `(batch_size,)`.
      paragraph_mask: Invalid position mask such as query and special symbols
-        (e.g. PAD, SEP, CLS) of shape (batch_size,).
+        (e.g. PAD, SEP, CLS) of shape `(batch_size,)`.
      start_positions: The start positions of each example of shape
-        (batch_size,).
+        `(batch_size,)`.
      training: Whether or not this is the training phase.

    Returns:
-      A dictionary with the keys 'start_predictions', 'end_predictions',
-      'start_logits', 'end_logits'.
-      If inference, then 'start_top_predictions', 'start_top_index',
-      'end_top_predictions', 'end_top_index' are also included.
+      A dictionary with the keys `start_predictions`, `end_predictions`,
+      `start_logits`, `end_logits`.
+      If inference, then `start_top_predictions`, `start_top_index`,
+      `end_top_predictions`, `end_top_index` are also included.
    """
    paragraph_mask = tf.cast(paragraph_mask, dtype=sequence_data.dtype)
...
@@ -20,6 +20,7 @@ from typing import Any, Callable, Dict, Tuple
import tensorflow as tf
from tensorflow.python.framework import dtypes
+from official.modeling import tf_utils

Output = Tuple[tf.Tensor, tf.Tensor]
InternalState = Tuple[tf.Tensor, tf.Tensor, tf.Tensor, Dict]
...
@@ -64,15 +65,7 @@ def log_prob_from_logits(logits):

def shape_list(tensor):
  """Return a list of the tensor's shape, and ensure no None values in list."""
-  # Get statically known shape (may contain None's for unknown dimensions)
-  shape = tensor.get_shape().as_list()
-
-  # Ensure that the shape values are not None
-  dynamic_shape = tf.shape(tensor)
-  for i in range(len(shape)):  # pylint: disable=consider-using-enumerate
-    if shape[i] is None:
-      shape[i] = dynamic_shape[i]
-  return shape
+  return tf_utils.get_shape_list(tensor)


def get_shape_keep_last_dim(tensor):
...
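The deleted body is equivalent to this sketch of what `tf_utils.get_shape_list` provides (the real helper additionally supports an `expected_rank` check):

```python
import tensorflow as tf

def get_shape_list_sketch(tensor):
  """Mixes static and dynamic shapes: ints where known, scalar tensors where not."""
  static_shape = tensor.shape.as_list()
  dynamic_shape = tf.shape(tensor)
  return [dim if dim is not None else dynamic_shape[i]
          for i, dim in enumerate(static_shape)]
```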
@@ -76,14 +76,15 @@ def sample_top_p(logits, top_p):
  """
  sorted_indices = tf.argsort(logits, direction="DESCENDING")
  # Flatten logits as tf.gather on TPU needs axis to be compile time constant.
-  range_for_gather = tf.expand_dims(tf.range(0, logits.shape[0]), axis=1)
-  range_for_gather = tf.tile(range_for_gather * logits.shape[1],
-                             [1, logits.shape[1]]) + sorted_indices
+  logits_shape = decoding_module.shape_list(logits)
+  range_for_gather = tf.expand_dims(tf.range(0, logits_shape[0]), axis=1)
+  range_for_gather = tf.tile(range_for_gather * logits_shape[1],
+                             [1, logits_shape[1]]) + sorted_indices
  flattened_logits = tf.reshape(logits, [-1])
  flattened_sorted_indices = tf.reshape(range_for_gather, [-1])
  sorted_logits = tf.reshape(
      tf.gather(flattened_logits, flattened_sorted_indices),
-      [logits.shape[0], logits.shape[1]])
+      [logits_shape[0], logits_shape[1]])
  cumulative_probs = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
  # Remove tokens with cumulative probability above the threshold.
...
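For context, the surrounding function implements nucleus (top-p) sampling. A minimal, self-contained sketch of the standard filtering technique, assuming rank-2 `float32` logits (this is not the library routine):

```python
import tensorflow as tf

def top_p_filter(logits, top_p=0.9):
  """Masks logits outside the smallest set with cumulative probability > top_p."""
  sorted_logits = tf.sort(logits, direction="DESCENDING", axis=-1)
  cumulative = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
  # Keep a token while the cumulative mass *before* it is below top_p,
  # so the top-1 token is always kept.
  keep = tf.concat([tf.zeros_like(cumulative[:, :1]),
                    cumulative[:, :-1]], axis=-1) < top_p
  # Per-row threshold: the smallest sorted logit that is still kept.
  threshold = tf.reduce_min(
      tf.where(keep, sorted_logits,
               tf.fill(tf.shape(sorted_logits), sorted_logits.dtype.max)),
      axis=-1, keepdims=True)
  return tf.where(logits < threshold, tf.fill(tf.shape(logits), -1e9), logits)
```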
@@ -113,7 +113,7 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
  correct way of using L2 regularization/weight decay with Adam, since that will
  interact with the m and v parameters in strange ways.

-  Instead we want ot decay the weights in a manner that doesn't interact with
+  Instead we want to decay the weights in a manner that doesn't interact with
  the m/v parameters. This is equivalent to adding the square of the weights to
  the loss with plain (non-momentum) SGD.
  """
...
@@ -171,7 +171,8 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
      # and passed the allreduced grads_and_vars. For now, the
      # clip_by_global_norm will be moved to before the explicit allreduce to
      # keep the math the same as TF 1 and pre TF 2.2 implementation.
-      (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
+      (grads, _) = tf.clip_by_global_norm(
+          grads, clip_norm=self.gradient_clip_norm)
    return super(AdamWeightDecay, self).apply_gradients(
        zip(grads, tvars),
        name=name,
...
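The hardcoded `clip_norm=1.0` becomes the configurable `gradient_clip_norm` attribute (the attribute name is taken from the diff). As a quick reminder of `tf.clip_by_global_norm` semantics:

```python
import tensorflow as tf

grads = [tf.constant([3.0, 4.0])]   # global norm = 5.0
clipped, global_norm = tf.clip_by_global_norm(grads, clip_norm=2.5)
# global_norm -> 5.0; clipped[0] -> [1.5, 2.0]
# Each gradient is scaled by clip_norm / max(global_norm, clip_norm).
```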
@@ -98,13 +98,14 @@ class TaggingTask(base_task.Task):
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=self.task_config.model.head_initializer_range),
        dropout_rate=self.task_config.model.head_dropout,
-        output='logits')
+        output='logits',
+        output_encoder_outputs=True)

  def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
-    model_outputs = tf.cast(model_outputs, tf.float32)
+    logits = tf.cast(model_outputs['logits'], tf.float32)
    masked_labels, masked_weights = _masked_labels_and_weights(labels)
    loss = tf.keras.losses.sparse_categorical_crossentropy(
-        masked_labels, model_outputs, from_logits=True)
+        masked_labels, logits, from_logits=True)
    numerator_loss = tf.reduce_sum(loss * masked_weights)
    denominator_loss = tf.reduce_sum(masked_weights)
    loss = tf.math.divide_no_nan(numerator_loss, denominator_loss)
...
@@ -139,7 +140,7 @@ class TaggingTask(base_task.Task):
  def inference_step(self, inputs, model: tf.keras.Model):
    """Performs the forward step."""
-    logits = model(inputs, training=False)
+    logits = model(inputs, training=False)['logits']
    return {'logits': logits,
            'predict_ids': tf.argmax(logits, axis=-1, output_type=tf.int32)}
...
@@ -156,7 +157,7 @@ class TaggingTask(base_task.Task):
    """
    features, labels = inputs
    outputs = self.inference_step(features, model)
-    loss = self.build_losses(labels=labels, model_outputs=outputs['logits'])
+    loss = self.build_losses(labels=labels, model_outputs=outputs)

    # Negative label ids are padding labels which should be ignored.
    real_label_index = tf.where(tf.greater_equal(labels, 0))
...
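A standalone sketch of the updated contract, where model outputs are a dict keyed by `'logits'` and negative label ids mark padding; `_masked_labels_and_weights` is approximated here from its usage, not copied from the source:

```python
import tensorflow as tf

def masked_tagging_loss(labels, model_outputs):
  """Masked sparse cross-entropy over a dict-style model output."""
  logits = tf.cast(model_outputs['logits'], tf.float32)
  # Negative label ids mark padding; zero them out and weight them 0.
  masked_weights = tf.cast(tf.greater_equal(labels, 0), tf.float32)
  masked_labels = tf.where(labels >= 0, labels, tf.zeros_like(labels))
  loss = tf.keras.losses.sparse_categorical_crossentropy(
      masked_labels, logits, from_logits=True)
  return tf.math.divide_no_nan(
      tf.reduce_sum(loss * masked_weights),
      tf.reduce_sum(masked_weights))
```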
@@ -302,7 +302,6 @@ class TranslationTask(base_task.Task):
    logs = {self.loss: loss}
    if metrics:
      self.process_metrics(metrics, inputs["targets"], outputs)
-      logs.update({m.name: m.result() for m in metrics})
    return logs

  def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
...
@@ -51,7 +51,7 @@ class MockTask(base_task.Task):
  def build_model(self, *arg, **kwargs):
    inputs = tf.keras.layers.Input(shape=(2,), name="random", dtype=tf.float32)
    outputs = tf.keras.layers.Dense(
-        1, bias_initializer=tf.keras.initializers.Ones())(
+        1, bias_initializer=tf.keras.initializers.Ones(), name="dense_0")(
            inputs)
    network = tf.keras.Model(inputs=inputs, outputs=outputs)
    return MockModel(network)
...
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.