Unverified Commit 0cceabfc authored by Yiming Shi's avatar Yiming Shi Committed by GitHub
Browse files

Merge branch 'master' into move_to_keraslayers_fasterrcnn_fpn_keras_feature_extractor

parents 17821c0d 39ee0ac9
...@@ -34,7 +34,6 @@ class MaskedLMTest(keras_parameterized.TestCase): ...@@ -34,7 +34,6 @@ class MaskedLMTest(keras_parameterized.TestCase):
def create_layer(self, def create_layer(self,
vocab_size, vocab_size,
sequence_length,
hidden_size, hidden_size,
output='predictions', output='predictions',
xformer_stack=None): xformer_stack=None):
...@@ -44,7 +43,6 @@ class MaskedLMTest(keras_parameterized.TestCase): ...@@ -44,7 +43,6 @@ class MaskedLMTest(keras_parameterized.TestCase):
xformer_stack = transformer_encoder.TransformerEncoder( xformer_stack = transformer_encoder.TransformerEncoder(
vocab_size=vocab_size, vocab_size=vocab_size,
num_layers=1, num_layers=1,
sequence_length=sequence_length,
hidden_size=hidden_size, hidden_size=hidden_size,
num_attention_heads=4, num_attention_heads=4,
) )
...@@ -62,7 +60,6 @@ class MaskedLMTest(keras_parameterized.TestCase): ...@@ -62,7 +60,6 @@ class MaskedLMTest(keras_parameterized.TestCase):
num_predictions = 21 num_predictions = 21
test_layer = self.create_layer( test_layer = self.create_layer(
vocab_size=vocab_size, vocab_size=vocab_size,
sequence_length=sequence_length,
hidden_size=hidden_size) hidden_size=hidden_size)
# Make sure that the output tensor of the masked LM is the right shape. # Make sure that the output tensor of the masked LM is the right shape.
...@@ -81,19 +78,16 @@ class MaskedLMTest(keras_parameterized.TestCase): ...@@ -81,19 +78,16 @@ class MaskedLMTest(keras_parameterized.TestCase):
xformer_stack = transformer_encoder.TransformerEncoder( xformer_stack = transformer_encoder.TransformerEncoder(
vocab_size=vocab_size, vocab_size=vocab_size,
num_layers=1, num_layers=1,
sequence_length=sequence_length,
hidden_size=hidden_size, hidden_size=hidden_size,
num_attention_heads=4, num_attention_heads=4,
) )
test_layer = self.create_layer( test_layer = self.create_layer(
vocab_size=vocab_size, vocab_size=vocab_size,
sequence_length=sequence_length,
hidden_size=hidden_size, hidden_size=hidden_size,
xformer_stack=xformer_stack, xformer_stack=xformer_stack,
output='predictions') output='predictions')
logit_layer = self.create_layer( logit_layer = self.create_layer(
vocab_size=vocab_size, vocab_size=vocab_size,
sequence_length=sequence_length,
hidden_size=hidden_size, hidden_size=hidden_size,
xformer_stack=xformer_stack, xformer_stack=xformer_stack,
output='logits') output='logits')
...@@ -134,7 +128,6 @@ class MaskedLMTest(keras_parameterized.TestCase): ...@@ -134,7 +128,6 @@ class MaskedLMTest(keras_parameterized.TestCase):
num_predictions = 21 num_predictions = 21
test_layer = self.create_layer( test_layer = self.create_layer(
vocab_size=vocab_size, vocab_size=vocab_size,
sequence_length=sequence_length,
hidden_size=hidden_size) hidden_size=hidden_size)
# Create a model from the masked LM layer. # Create a model from the masked LM layer.
...@@ -155,7 +148,7 @@ class MaskedLMTest(keras_parameterized.TestCase): ...@@ -155,7 +148,7 @@ class MaskedLMTest(keras_parameterized.TestCase):
def test_unknown_output_type_fails(self): def test_unknown_output_type_fails(self):
with self.assertRaisesRegex(ValueError, 'Unknown `output` value "bad".*'): with self.assertRaisesRegex(ValueError, 'Unknown `output` value "bad".*'):
_ = self.create_layer( _ = self.create_layer(
vocab_size=8, sequence_length=8, hidden_size=8, output='bad') vocab_size=8, hidden_size=8, output='bad')
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -13,7 +13,8 @@ ...@@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Multi-channel decoder.""" """Multi-channel Attention."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -24,11 +25,24 @@ import math ...@@ -24,11 +25,24 @@ import math
import tensorflow as tf import tensorflow as tf
from official.modeling import tf_utils from official.modeling import tf_utils
from official.nlp.modeling import layers from official.nlp.modeling.layers import attention
from official.nlp.modeling.layers import masked_softmax
class DocAttention(tf.keras.layers.Layer): class VotingAttention(tf.keras.layers.Layer):
"""Documents Attention layer.""" """Voting Attention layer.
Arguments:
num_heads: the number of attention heads.
head_size: per-head hidden size.
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels.
bias_regularizer: Regularizer for dense layer biases.
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer kernels.
"""
def __init__(self, def __init__(self,
num_heads, num_heads,
...@@ -41,7 +55,7 @@ class DocAttention(tf.keras.layers.Layer): ...@@ -41,7 +55,7 @@ class DocAttention(tf.keras.layers.Layer):
kernel_constraint=None, kernel_constraint=None,
bias_constraint=None, bias_constraint=None,
**kwargs): **kwargs):
super(DocAttention, self).__init__(**kwargs) super(VotingAttention, self).__init__(**kwargs)
self._num_heads = num_heads self._num_heads = num_heads
self._head_size = head_size self._head_size = head_size
self._kernel_initializer = tf.keras.initializers.get(kernel_initializer) self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
...@@ -52,29 +66,27 @@ class DocAttention(tf.keras.layers.Layer): ...@@ -52,29 +66,27 @@ class DocAttention(tf.keras.layers.Layer):
self._bias_constraint = tf.keras.constraints.get(bias_constraint) self._bias_constraint = tf.keras.constraints.get(bias_constraint)
def build(self, unused_input_shapes): def build(self, unused_input_shapes):
self._query_dense = layers.DenseEinsum( common_kwargs = dict(
output_shape=(self._num_heads, self._head_size),
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer, bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer, kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer, bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer, activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint, kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint, bias_constraint=self._bias_constraint)
dtype=self.dtype, self._query_dense = tf.keras.layers.experimental.EinsumDense(
name="encdocatt_query") "BAE,ENH->BANH",
self._key_dense = layers.DenseEinsum( output_shape=(None, self._num_heads, self._head_size),
output_shape=(self._num_heads, self._head_size), bias_axes="NH",
kernel_initializer=self._kernel_initializer, name="query",
bias_initializer=self._bias_initializer, **common_kwargs)
kernel_regularizer=self._kernel_regularizer, self._key_dense = tf.keras.layers.experimental.EinsumDense(
bias_regularizer=self._bias_regularizer, "BAE,ENH->BANH",
activity_regularizer=self._activity_regularizer, output_shape=(None, self._num_heads, self._head_size),
kernel_constraint=self._kernel_constraint, bias_axes="NH",
bias_constraint=self._bias_constraint, name="key",
dtype=self.dtype, **common_kwargs)
name="encdocatt_key") super(VotingAttention, self).build(unused_input_shapes)
super(DocAttention, self).build(unused_input_shapes)
def call(self, encoder_outputs, doc_attention_mask): def call(self, encoder_outputs, doc_attention_mask):
num_docs = tf_utils.get_shape_list(encoder_outputs, expected_rank=[4])[1] num_docs = tf_utils.get_shape_list(encoder_outputs, expected_rank=[4])[1]
...@@ -95,33 +107,55 @@ class DocAttention(tf.keras.layers.Layer): ...@@ -95,33 +107,55 @@ class DocAttention(tf.keras.layers.Layer):
return tf.nn.softmax(doc_attention_probs + infadder) return tf.nn.softmax(doc_attention_probs + infadder)
class MultiChannelAttention(layers.MultiHeadAttention): class MultiChannelAttention(attention.MultiHeadAttention):
"""Multi-channel Attention layer.""" """Multi-channel Attention layer.
def build(self, input_shape): Introduced in, [Generating Representative Headlines for News Stories
super(MultiChannelAttention, self).build(input_shape) ](https://arxiv.org/abs/2001.09386). Expects multiple cross-attention
self._masked_softmax = layers.MaskedSoftmax(mask_expansion_axes=[2]) target sequences.
def call(self, inputs, attention_mask=None): Call args:
from_tensor = inputs[0] query: Query `Tensor` of shape `[B, T, dim]`.
to_tensor = inputs[1] value: Value `Tensor` of shape `[B, A, S, dim]`, where A denotes the
doc_attention_probs = inputs[2] context_attention_weights: Context weights of shape `[B, N, T, A]`, where N
is the number of attention heads. Combines multi-channel sources
context tensors according to the distribution among channels.
key: Optional key `Tensor` of shape `[B, A, S, dim]`. If not given, will use
`value` for both `key` and `value`, which is the most common case.
attention_mask: a boolean mask of shape `[B, T, S]`, that prevents attention
to certain positions.
"""
def build_attention(self, rank):
super(MultiChannelAttention, self).build_attention(rank)
self._masked_softmax = masked_softmax.MaskedSoftmax(mask_expansion_axes=[2])
def call(self,
query,
value,
key=None,
context_attention_weights=None,
attention_mask=None):
if not self._built_from_signature:
self._build_from_signature(query, value, key=key)
if key is None:
key = value
# Scalar dimensions referenced here: # Scalar dimensions referenced here:
# B = batch size (number of stories) # B = batch size (number of stories)
# A = num_docs (number of docs) # A = num_docs (number of docs)
# F = `from_tensor` sequence length # F = target sequence length
# T = `to_tensor` sequence length # T = source sequence length
# N = `num_attention_heads` # N = `num_attention_heads`
# H = `size_per_head` # H = `size_per_head`
# `query_tensor` = [B, F, N ,H] # `query_tensor` = [B, F, N ,H]
query_tensor = self._query_dense(from_tensor) query_tensor = self._query_dense(query)
# `key_tensor` = [B, A, T, N, H] # `key_tensor` = [B, A, T, N, H]
key_tensor = self._key_dense(to_tensor) key_tensor = self._key_dense(key)
# `value_tensor` = [B, A, T, N, H] # `value_tensor` = [B, A, T, N, H]
value_tensor = self._value_dense(to_tensor) value_tensor = self._value_dense(value)
# Take the dot product between "query" and "key" to get the raw # Take the dot product between "query" and "key" to get the raw
# attention scores. # attention scores.
...@@ -140,7 +174,7 @@ class MultiChannelAttention(layers.MultiHeadAttention): ...@@ -140,7 +174,7 @@ class MultiChannelAttention(layers.MultiHeadAttention):
# `context_layer` = [B, F, N, H] # `context_layer` = [B, F, N, H]
context_layer = tf.einsum("BANFT,BATNH->BAFNH", attention_probs, context_layer = tf.einsum("BANFT,BATNH->BAFNH", attention_probs,
value_tensor) value_tensor)
attention_output = tf.einsum("BNFA,BAFNH->BFNH", doc_attention_probs, attention_output = tf.einsum("BNFA,BAFNH->BFNH", context_attention_weights,
context_layer) context_layer)
attention_output = self._output_dense(attention_output) attention_output = self._output_dense(attention_output)
return attention_output return attention_output
...@@ -22,14 +22,15 @@ from __future__ import print_function ...@@ -22,14 +22,15 @@ from __future__ import print_function
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from official.nlp.nhnet import multi_channel_attention from official.nlp.modeling.layers import multi_channel_attention
class MultiChannelAttentionTest(tf.test.TestCase): class MultiChannelAttentionTest(tf.test.TestCase):
def test_doc_attention(self): def test_doc_attention(self):
num_heads = 2 num_heads = 2
doc_attention = multi_channel_attention.DocAttention(num_heads, head_size=8) doc_attention = multi_channel_attention.VotingAttention(
num_heads, head_size=8)
num_docs = 3 num_docs = 3
inputs = np.zeros((2, num_docs, 10, 16), dtype=np.float32) inputs = np.zeros((2, num_docs, 10, 16), dtype=np.float32)
doc_mask = np.zeros((2, num_docs), dtype=np.float32) doc_mask = np.zeros((2, num_docs), dtype=np.float32)
...@@ -47,7 +48,11 @@ class MultiChannelAttentionTest(tf.test.TestCase): ...@@ -47,7 +48,11 @@ class MultiChannelAttentionTest(tf.test.TestCase):
mask_data = np.random.randint(2, size=(3, num_docs, 4, 2)) mask_data = np.random.randint(2, size=(3, num_docs, 4, 2))
doc_probs = np.random.randint( doc_probs = np.random.randint(
2, size=(3, num_heads, 4, num_docs)).astype(float) 2, size=(3, num_heads, 4, num_docs)).astype(float)
outputs = attention_layer([from_data, to_data, doc_probs], mask_data) outputs = attention_layer(
query=from_data,
value=to_data,
context_attention_weights=doc_probs,
attention_mask=mask_data)
self.assertEqual(outputs.shape, (3, 4, 8)) self.assertEqual(outputs.shape, (3, 4, 8))
......
...@@ -38,6 +38,9 @@ class OnDeviceEmbedding(tf.keras.layers.Layer): ...@@ -38,6 +38,9 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
lookup. Defaults to False (that is, using tf.gather). Setting this option lookup. Defaults to False (that is, using tf.gather). Setting this option
to True may improve performance, especially on small vocabulary sizes, but to True may improve performance, especially on small vocabulary sizes, but
will generally require more memory. will generally require more memory.
use_scale: Whether to scale the output embeddings. Defaults to False (that
is, not to scale). Setting this option to True will let values in output
embeddings multiplied by self._embedding_width ** 0.5.
""" """
def __init__(self, def __init__(self,
...@@ -45,6 +48,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer): ...@@ -45,6 +48,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
embedding_width, embedding_width,
initializer="glorot_uniform", initializer="glorot_uniform",
use_one_hot=False, use_one_hot=False,
use_scale=False,
**kwargs): **kwargs):
super(OnDeviceEmbedding, self).__init__(**kwargs) super(OnDeviceEmbedding, self).__init__(**kwargs)
...@@ -52,6 +56,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer): ...@@ -52,6 +56,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
self._embedding_width = embedding_width self._embedding_width = embedding_width
self._initializer = initializer self._initializer = initializer
self._use_one_hot = use_one_hot self._use_one_hot = use_one_hot
self._use_scale = use_scale
def get_config(self): def get_config(self):
config = { config = {
...@@ -59,6 +64,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer): ...@@ -59,6 +64,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
"embedding_width": self._embedding_width, "embedding_width": self._embedding_width,
"initializer": self._initializer, "initializer": self._initializer,
"use_one_hot": self._use_one_hot, "use_one_hot": self._use_one_hot,
"use_scale": self._use_scale,
} }
base_config = super(OnDeviceEmbedding, self).get_config() base_config = super(OnDeviceEmbedding, self).get_config()
return dict(list(base_config.items()) + list(config.items())) return dict(list(base_config.items()) + list(config.items()))
...@@ -85,4 +91,6 @@ class OnDeviceEmbedding(tf.keras.layers.Layer): ...@@ -85,4 +91,6 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
# Work around b/142213824: prefer concat to shape over a Python list. # Work around b/142213824: prefer concat to shape over a Python list.
tf.concat([tf.shape(inputs), [self._embedding_width]], axis=0)) tf.concat([tf.shape(inputs), [self._embedding_width]], axis=0))
embeddings.set_shape(inputs.shape.as_list() + [self._embedding_width]) embeddings.set_shape(inputs.shape.as_list() + [self._embedding_width])
if self._use_scale:
embeddings *= self._embedding_width ** 0.5
return embeddings return embeddings
...@@ -193,6 +193,26 @@ class OnDeviceEmbeddingTest(keras_parameterized.TestCase): ...@@ -193,6 +193,26 @@ class OnDeviceEmbeddingTest(keras_parameterized.TestCase):
output = model.predict(input_data) output = model.predict(input_data)
self.assertEqual(tf.float16, output.dtype) self.assertEqual(tf.float16, output.dtype)
def test_use_scale_layer_invocation(self):
vocab_size = 31
embedding_width = 27
test_layer = on_device_embedding.OnDeviceEmbedding(
vocab_size=vocab_size, embedding_width=embedding_width, use_scale=True)
# Create a 2-dimensional input (the first dimension is implicit).
sequence_length = 23
input_tensor = tf.keras.Input(shape=(sequence_length), dtype=tf.int32)
output_tensor = test_layer(input_tensor)
# Create a model from the test layer.
model = tf.keras.Model(input_tensor, output_tensor)
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 3
input_data = np.random.randint(
vocab_size, size=(batch_size, sequence_length))
output = model.predict(input_data)
self.assertEqual(tf.float32, output.dtype)
if __name__ == "__main__": if __name__ == "__main__":
tf.test.main() tf.test.main()
...@@ -160,7 +160,6 @@ class RelativePositionEmbedding(tf.keras.layers.Layer): ...@@ -160,7 +160,6 @@ class RelativePositionEmbedding(tf.keras.layers.Layer):
"hidden_size": self._hidden_size, "hidden_size": self._hidden_size,
"min_timescale": self._min_timescale, "min_timescale": self._min_timescale,
"max_timescale": self._max_timescale, "max_timescale": self._max_timescale,
"length": self._length,
} }
base_config = super(RelativePositionEmbedding, self).get_config() base_config = super(RelativePositionEmbedding, self).get_config()
return dict(list(base_config.items()) + list(config.items())) return dict(list(base_config.items()) + list(config.items()))
......
...@@ -23,7 +23,6 @@ import gin ...@@ -23,7 +23,6 @@ import gin
import tensorflow as tf import tensorflow as tf
from official.nlp.modeling.layers import attention from official.nlp.modeling.layers import attention
from official.nlp.modeling.layers import dense_einsum
@tf.keras.utils.register_keras_serializable(package="Text") @tf.keras.utils.register_keras_serializable(package="Text")
...@@ -109,19 +108,20 @@ class ReZeroTransformer(tf.keras.layers.Layer): ...@@ -109,19 +108,20 @@ class ReZeroTransformer(tf.keras.layers.Layer):
"The input size (%d) is not a multiple of the number of attention " "The input size (%d) is not a multiple of the number of attention "
"heads (%d)" % (hidden_size, self._num_heads)) "heads (%d)" % (hidden_size, self._num_heads))
self._attention_head_size = int(hidden_size // self._num_heads) self._attention_head_size = int(hidden_size // self._num_heads)
common_kwargs = dict(
self._attention_layer = attention.MultiHeadAttention(
num_heads=self._num_heads,
key_size=self._attention_head_size,
dropout=self._attention_dropout_rate,
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer, bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer, kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer, bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer, activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint, kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint, bias_constraint=self._bias_constraint)
name="self_attention") self._attention_layer = attention.MultiHeadAttention(
num_heads=self._num_heads,
key_size=self._attention_head_size,
dropout=self._attention_dropout_rate,
name="self_attention",
**common_kwargs)
self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate) self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
if self._use_layer_norm: if self._use_layer_norm:
# Use float32 in layernorm for numeric stability. # Use float32 in layernorm for numeric stability.
...@@ -132,17 +132,12 @@ class ReZeroTransformer(tf.keras.layers.Layer): ...@@ -132,17 +132,12 @@ class ReZeroTransformer(tf.keras.layers.Layer):
axis=-1, axis=-1,
epsilon=1e-12, epsilon=1e-12,
dtype=tf.float32)) dtype=tf.float32))
self._intermediate_dense = dense_einsum.DenseEinsum( self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
output_shape=self._intermediate_size, "abc,cd->abd",
activation=None, output_shape=(None, self._intermediate_size),
kernel_initializer=self._kernel_initializer, bias_axes="d",
bias_initializer=self._bias_initializer, name="intermediate",
kernel_regularizer=self._kernel_regularizer, **common_kwargs)
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="intermediate")
policy = tf.keras.mixed_precision.experimental.global_policy() policy = tf.keras.mixed_precision.experimental.global_policy()
if policy.name == "mixed_bfloat16": if policy.name == "mixed_bfloat16":
# bfloat16 causes BERT with the LAMB optimizer to not converge # bfloat16 causes BERT with the LAMB optimizer to not converge
...@@ -151,16 +146,12 @@ class ReZeroTransformer(tf.keras.layers.Layer): ...@@ -151,16 +146,12 @@ class ReZeroTransformer(tf.keras.layers.Layer):
policy = tf.float32 policy = tf.float32
self._intermediate_activation_layer = tf.keras.layers.Activation( self._intermediate_activation_layer = tf.keras.layers.Activation(
self._intermediate_activation, dtype=policy) self._intermediate_activation, dtype=policy)
self._output_dense = dense_einsum.DenseEinsum( self._output_dense = tf.keras.layers.experimental.EinsumDense(
output_shape=hidden_size, "abc,cd->abd",
kernel_initializer=self._kernel_initializer, output_shape=(None, hidden_size),
bias_initializer=self._bias_initializer, bias_axes="d",
kernel_regularizer=self._kernel_regularizer, name="output",
bias_regularizer=self._bias_regularizer, **common_kwargs)
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="output")
self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate) self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
if self._use_layer_norm: if self._use_layer_norm:
# Use float32 in layernorm for numeric stability. # Use float32 in layernorm for numeric stability.
...@@ -222,9 +213,9 @@ class ReZeroTransformer(tf.keras.layers.Layer): ...@@ -222,9 +213,9 @@ class ReZeroTransformer(tf.keras.layers.Layer):
attention_mask = attention_mask[:, 0:self._output_range, :] attention_mask = attention_mask[:, 0:self._output_range, :]
else: else:
target_tensor = input_tensor target_tensor = input_tensor
attention_inputs = [target_tensor, input_tensor]
attention_output = self._attention_layer(attention_inputs, attention_mask) attention_output = self._attention_layer(
query=target_tensor, value=input_tensor, attention_mask=attention_mask)
attention_output = self._attention_dropout(attention_output) attention_output = self._attention_dropout(attention_output)
attention_output = target_tensor + self._rezero_a * attention_output attention_output = target_tensor + self._rezero_a * attention_output
if self._use_layer_norm: if self._use_layer_norm:
......
...@@ -58,7 +58,7 @@ class TalkingHeadsAttention(attention.MultiHeadAttention): ...@@ -58,7 +58,7 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
bias_constraint: Constraint for dense layer kernels. bias_constraint: Constraint for dense layer kernels.
""" """
def _build_attention(self, qkv_rank): def build_attention(self, qkv_rank):
"""Builds multi-head dot-product attention computations. """Builds multi-head dot-product attention computations.
This function overrides base class to create additional linear projection This function overrides base class to create additional linear projection
...@@ -67,7 +67,7 @@ class TalkingHeadsAttention(attention.MultiHeadAttention): ...@@ -67,7 +67,7 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
Args: Args:
qkv_rank: the rank of query, key, value tensors after projection. qkv_rank: the rank of query, key, value tensors after projection.
""" """
super(TalkingHeadsAttention, self)._build_attention(qkv_rank) super(TalkingHeadsAttention, self).build_attention(qkv_rank)
# Build an equation: # Build an equation:
# (<batch_dims>, num_heads_a, ...),(num_heads_a, num_heads_b) -> # (<batch_dims>, num_heads_a, ...),(num_heads_a, num_heads_b) ->
...@@ -103,11 +103,11 @@ class TalkingHeadsAttention(attention.MultiHeadAttention): ...@@ -103,11 +103,11 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
dtype=self.dtype, dtype=self.dtype,
trainable=True) trainable=True)
def _compute_attention(self, def compute_attention(self,
query_tensor, query_tensor,
key_tensor, key_tensor,
value_tensor, value_tensor,
attention_mask=None): attention_mask=None):
"""Applies Dot-product attention with query, key, value tensors. """Applies Dot-product attention with query, key, value tensors.
This function overrides base class to apply additional linear projection This function overrides base class to apply additional linear projection
......
...@@ -46,7 +46,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase): ...@@ -46,7 +46,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80)) query = tf.keras.Input(shape=(40, 80))
value = tf.keras.Input(shape=(20, 80)) value = tf.keras.Input(shape=(20, 80))
output = test_layer([query, value]) output = test_layer(query=query, value=value)
self.assertEqual(output.shape.as_list(), [None] + output_dims) self.assertEqual(output.shape.as_list(), [None] + output_dims)
def test_non_masked_self_attention(self): def test_non_masked_self_attention(self):
...@@ -55,7 +55,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase): ...@@ -55,7 +55,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
num_heads=12, key_size=64) num_heads=12, key_size=64)
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80)) query = tf.keras.Input(shape=(40, 80))
output = test_layer([query, query]) output = test_layer(query=query, value=query)
self.assertEqual(output.shape.as_list(), [None, 40, 80]) self.assertEqual(output.shape.as_list(), [None, 40, 80])
def test_attention_scores(self): def test_attention_scores(self):
...@@ -64,7 +64,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase): ...@@ -64,7 +64,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
num_heads=12, key_size=64, return_attention_scores=True) num_heads=12, key_size=64, return_attention_scores=True)
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80)) query = tf.keras.Input(shape=(40, 80))
output, coef = test_layer([query, query]) output, coef = test_layer(query=query, value=query)
self.assertEqual(output.shape.as_list(), [None, 40, 80]) self.assertEqual(output.shape.as_list(), [None, 40, 80])
self.assertEqual(coef.shape.as_list(), [None, 12, 40, 40]) self.assertEqual(coef.shape.as_list(), [None, 12, 40, 40])
...@@ -78,7 +78,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase): ...@@ -78,7 +78,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
query = tf.keras.Input(shape=(4, 8)) query = tf.keras.Input(shape=(4, 8))
value = tf.keras.Input(shape=(2, 8)) value = tf.keras.Input(shape=(2, 8))
mask_tensor = tf.keras.Input(shape=(4, 2)) mask_tensor = tf.keras.Input(shape=(4, 2))
output = test_layer([query, value], mask_tensor) output = test_layer(query=query, value=value, attention_mask=mask_tensor)
# Create a model containing the test layer. # Create a model containing the test layer.
model = tf.keras.Model([query, value, mask_tensor], output) model = tf.keras.Model([query, value, mask_tensor], output)
...@@ -102,7 +102,8 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase): ...@@ -102,7 +102,8 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
# Tests the layer with three inputs: Q, K, V. # Tests the layer with three inputs: Q, K, V.
key = tf.keras.Input(shape=(2, 8)) key = tf.keras.Input(shape=(2, 8))
output = test_layer([query, value, key], mask_tensor) output = test_layer(
query=query, value=value, key=key, attention_mask=mask_tensor)
model = tf.keras.Model([query, value, key, mask_tensor], output) model = tf.keras.Model([query, value, key, mask_tensor], output)
masked_output_data = model.predict([from_data, to_data, to_data, mask_data]) masked_output_data = model.predict([from_data, to_data, to_data, mask_data])
...@@ -127,7 +128,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase): ...@@ -127,7 +128,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02)) kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80)) query = tf.keras.Input(shape=(40, 80))
output = test_layer([query, query]) output = test_layer(query=query, value=query)
self.assertEqual(output.shape.as_list(), [None, 40, 80]) self.assertEqual(output.shape.as_list(), [None, 40, 80])
@parameterized.named_parameters( @parameterized.named_parameters(
...@@ -149,11 +150,12 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase): ...@@ -149,11 +150,12 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
# Invoke the data with a random set of mask data. This should mask at least # Invoke the data with a random set of mask data. This should mask at least
# one element. # one element.
mask_data = np.random.randint(2, size=mask_shape).astype("bool") mask_data = np.random.randint(2, size=mask_shape).astype("bool")
output = test_layer([query, value], mask_data) output = test_layer(query=query, value=value, attention_mask=mask_data)
# Invoke the same data, but with a null mask (where no elements are masked). # Invoke the same data, but with a null mask (where no elements are masked).
null_mask_data = np.ones(mask_shape) null_mask_data = np.ones(mask_shape)
unmasked_output = test_layer([query, value], null_mask_data) unmasked_output = test_layer(
query=query, value=value, attention_mask=null_mask_data)
# Because one data is masked and one is not, the outputs should not be the # Because one data is masked and one is not, the outputs should not be the
# same. # same.
self.assertNotAllClose(output, unmasked_output) self.assertNotAllClose(output, unmasked_output)
......
...@@ -23,7 +23,7 @@ import gin ...@@ -23,7 +23,7 @@ import gin
import tensorflow as tf import tensorflow as tf
from official.nlp.modeling.layers import attention from official.nlp.modeling.layers import attention
from official.nlp.modeling.layers import dense_einsum from official.nlp.modeling.layers import multi_channel_attention
from official.nlp.modeling.layers.util import tf_function_if_eager from official.nlp.modeling.layers.util import tf_function_if_eager
...@@ -49,6 +49,12 @@ class Transformer(tf.keras.layers.Layer): ...@@ -49,6 +49,12 @@ class Transformer(tf.keras.layers.Layer):
activity_regularizer: Regularizer for dense layer activity. activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels. kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer kernels. bias_constraint: Constraint for dense layer kernels.
use_bias: Whether to enable use_bias in attention layer. If set False,
use_bias in attention layer is disabled.
norm_first: Whether to normalize inputs to attention and intermediate dense
layers. If set False, output of attention and intermediate dense layers is
normalized.
norm_epsilon: Epsilon value to initialize normalization layers.
""" """
def __init__(self, def __init__(self,
...@@ -65,6 +71,9 @@ class Transformer(tf.keras.layers.Layer): ...@@ -65,6 +71,9 @@ class Transformer(tf.keras.layers.Layer):
activity_regularizer=None, activity_regularizer=None,
kernel_constraint=None, kernel_constraint=None,
bias_constraint=None, bias_constraint=None,
use_bias=True,
norm_first=False,
norm_epsilon=1e-12,
**kwargs): **kwargs):
super(Transformer, self).__init__(**kwargs) super(Transformer, self).__init__(**kwargs)
...@@ -78,8 +87,12 @@ class Transformer(tf.keras.layers.Layer): ...@@ -78,8 +87,12 @@ class Transformer(tf.keras.layers.Layer):
self._bias_initializer = tf.keras.initializers.get(bias_initializer) self._bias_initializer = tf.keras.initializers.get(bias_initializer)
self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer) self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer) self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
self._kernel_constraint = tf.keras.constraints.get(kernel_constraint) self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
self._bias_constraint = tf.keras.constraints.get(bias_constraint) self._bias_constraint = tf.keras.constraints.get(bias_constraint)
self._use_bias = use_bias
self._norm_first = norm_first
self._norm_epsilon = norm_epsilon
def build(self, input_shape): def build(self, input_shape):
input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
...@@ -104,23 +117,21 @@ class Transformer(tf.keras.layers.Layer): ...@@ -104,23 +117,21 @@ class Transformer(tf.keras.layers.Layer):
"The input size (%d) is not a multiple of the number of attention " "The input size (%d) is not a multiple of the number of attention "
"heads (%d)" % (hidden_size, self._num_heads)) "heads (%d)" % (hidden_size, self._num_heads))
self._attention_head_size = int(hidden_size // self._num_heads) self._attention_head_size = int(hidden_size // self._num_heads)
common_kwargs = dict(
self._attention_layer = attention.MultiHeadAttention(
num_heads=self._num_heads,
key_size=self._attention_head_size,
dropout=self._attention_dropout_rate,
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer, bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer, kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer, bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer, activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint, kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint, bias_constraint=self._bias_constraint)
name="self_attention") self._attention_layer = attention.MultiHeadAttention(
# pylint: disable=protected-access num_heads=self._num_heads,
self._attention_layer.build([input_tensor_shape] * 3) key_size=self._attention_head_size,
self._attention_output_dense = self._attention_layer._output_dense dropout=self._attention_dropout_rate,
# pylint: enable=protected-access use_bias=self._use_bias,
name="self_attention",
**common_kwargs)
self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate) self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
# Use float32 in layernorm for numeric stability. # Use float32 in layernorm for numeric stability.
# It is probably safe in mixed_float16, but we haven't validated this yet. # It is probably safe in mixed_float16, but we haven't validated this yet.
...@@ -128,19 +139,14 @@ class Transformer(tf.keras.layers.Layer): ...@@ -128,19 +139,14 @@ class Transformer(tf.keras.layers.Layer):
tf.keras.layers.LayerNormalization( tf.keras.layers.LayerNormalization(
name="self_attention_layer_norm", name="self_attention_layer_norm",
axis=-1, axis=-1,
epsilon=1e-12, epsilon=self._norm_epsilon,
dtype=tf.float32)) dtype=tf.float32))
self._intermediate_dense = dense_einsum.DenseEinsum( self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
output_shape=self._intermediate_size, "abc,cd->abd",
activation=None, output_shape=(None, self._intermediate_size),
kernel_initializer=self._kernel_initializer, bias_axes="d",
bias_initializer=self._bias_initializer, name="intermediate",
kernel_regularizer=self._kernel_regularizer, **common_kwargs)
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="intermediate")
policy = tf.keras.mixed_precision.experimental.global_policy() policy = tf.keras.mixed_precision.experimental.global_policy()
if policy.name == "mixed_bfloat16": if policy.name == "mixed_bfloat16":
# bfloat16 causes BERT with the LAMB optimizer to not converge # bfloat16 causes BERT with the LAMB optimizer to not converge
...@@ -149,20 +155,19 @@ class Transformer(tf.keras.layers.Layer): ...@@ -149,20 +155,19 @@ class Transformer(tf.keras.layers.Layer):
policy = tf.float32 policy = tf.float32
self._intermediate_activation_layer = tf.keras.layers.Activation( self._intermediate_activation_layer = tf.keras.layers.Activation(
self._intermediate_activation, dtype=policy) self._intermediate_activation, dtype=policy)
self._output_dense = dense_einsum.DenseEinsum( self._output_dense = tf.keras.layers.experimental.EinsumDense(
output_shape=hidden_size, "abc,cd->abd",
kernel_initializer=self._kernel_initializer, output_shape=(None, hidden_size),
bias_initializer=self._bias_initializer, bias_axes="d",
kernel_regularizer=self._kernel_regularizer, name="output",
bias_regularizer=self._bias_regularizer, **common_kwargs)
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="output")
self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate) self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
# Use float32 in layernorm for numeric stability. # Use float32 in layernorm for numeric stability.
self._output_layer_norm = tf.keras.layers.LayerNormalization( self._output_layer_norm = tf.keras.layers.LayerNormalization(
name="output_layer_norm", axis=-1, epsilon=1e-12, dtype=tf.float32) name="output_layer_norm",
axis=-1,
epsilon=self._norm_epsilon,
dtype=tf.float32)
super(Transformer, self).build(input_shape) super(Transformer, self).build(input_shape)
...@@ -193,7 +198,13 @@ class Transformer(tf.keras.layers.Layer): ...@@ -193,7 +198,13 @@ class Transformer(tf.keras.layers.Layer):
"kernel_constraint": "kernel_constraint":
tf.keras.constraints.serialize(self._kernel_constraint), tf.keras.constraints.serialize(self._kernel_constraint),
"bias_constraint": "bias_constraint":
tf.keras.constraints.serialize(self._bias_constraint) tf.keras.constraints.serialize(self._bias_constraint),
"use_bias":
self._use_bias,
"norm_first":
self._norm_first,
"norm_epsilon":
self._norm_epsilon
} }
base_config = super(Transformer, self).get_config() base_config = super(Transformer, self).get_config()
return dict(list(base_config.items()) + list(config.items())) return dict(list(base_config.items()) + list(config.items()))
...@@ -208,13 +219,22 @@ class Transformer(tf.keras.layers.Layer): ...@@ -208,13 +219,22 @@ class Transformer(tf.keras.layers.Layer):
target_tensor = input_tensor[:, 0:self._output_range, :] target_tensor = input_tensor[:, 0:self._output_range, :]
attention_mask = attention_mask[:, 0:self._output_range, :] attention_mask = attention_mask[:, 0:self._output_range, :]
else: else:
if self._norm_first:
source_tensor = input_tensor
input_tensor = self._attention_layer_norm(input_tensor)
target_tensor = input_tensor target_tensor = input_tensor
attention_inputs = [target_tensor, input_tensor]
attention_output = self._attention_layer(attention_inputs, attention_mask) attention_output = self._attention_layer(
query=target_tensor, value=input_tensor, attention_mask=attention_mask)
attention_output = self._attention_dropout(attention_output) attention_output = self._attention_dropout(attention_output)
attention_output = self._attention_layer_norm(target_tensor + if self._norm_first:
attention_output) attention_output = source_tensor + attention_output
else:
attention_output = self._attention_layer_norm(target_tensor +
attention_output)
if self._norm_first:
source_attention_output = attention_output
attention_output = self._output_layer_norm(attention_output)
intermediate_output = self._intermediate_dense(attention_output) intermediate_output = self._intermediate_dense(attention_output)
intermediate_output = self._intermediate_activation_layer( intermediate_output = self._intermediate_activation_layer(
intermediate_output) intermediate_output)
...@@ -224,7 +244,10 @@ class Transformer(tf.keras.layers.Layer): ...@@ -224,7 +244,10 @@ class Transformer(tf.keras.layers.Layer):
# is always fp32 for now. Cast layer_output to fp32 for the subsequent # is always fp32 for now. Cast layer_output to fp32 for the subsequent
# add. # add.
layer_output = tf.cast(layer_output, tf.float32) layer_output = tf.cast(layer_output, tf.float32)
layer_output = self._output_layer_norm(layer_output + attention_output) if self._norm_first:
layer_output = source_attention_output + layer_output
else:
layer_output = self._output_layer_norm(layer_output + attention_output)
return layer_output return layer_output
...@@ -236,3 +259,259 @@ class CompiledTransformer(Transformer): ...@@ -236,3 +259,259 @@ class CompiledTransformer(Transformer):
@tf_function_if_eager(experimental_compile=True) @tf_function_if_eager(experimental_compile=True)
def call(self, inputs): def call(self, inputs):
return super(CompiledTransformer, self).call(inputs) return super(CompiledTransformer, self).call(inputs)
@tf.keras.utils.register_keras_serializable(package="Text")
class TransformerDecoderLayer(tf.keras.layers.Layer):
"""Single transformer layer for decoder.
It has three sub-layers:
(1) a multi-head self-attention mechanism.
(2) a encoder-decoder attention.
(3) a positionwise fully connected feed-forward network.
Arguments:
num_attention_heads: Number of attention heads.
intermediate_size: Size of the intermediate layer.
intermediate_activation: Activation for the intermediate layer.
dropout_rate: Dropout probability for the post-attention and output dropout.
attention_dropout_rate: Dropout probability for within the attention layer.
multi_channel_cross_attention: Whether to use `MultiChannelAttention` for
cross-attention between target sequences and source sequences.
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels.
bias_regularizer: Regularizer for dense layer biases.
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer kernels.
use_bias: Whether to enable use_bias in attention layer. If set False,
use_bias in attention layer is disabled.
norm_first: Whether to normalize inputs to attention and intermediate dense
layers. If set False, output of attention and intermediate dense layers is
normalized.
norm_epsilon: Epsilon value to initialize normalization layers.
"""
def __init__(self,
num_attention_heads,
intermediate_size,
intermediate_activation,
dropout_rate=0.0,
attention_dropout_rate=0.0,
multi_channel_cross_attention=False,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
use_bias=True,
norm_first=False,
norm_epsilon=1e-12,
**kwargs):
super(TransformerDecoderLayer, self).__init__(**kwargs)
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.intermediate_activation = tf.keras.activations.get(
intermediate_activation)
self.dropout_rate = dropout_rate
self.attention_dropout_rate = attention_dropout_rate
self.multi_channel_cross_attention = multi_channel_cross_attention
self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
self._bias_initializer = tf.keras.initializers.get(bias_initializer)
self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
self._bias_constraint = tf.keras.constraints.get(bias_constraint)
self._use_bias = use_bias
self._norm_first = norm_first
self._norm_epsilon = norm_epsilon
if self.multi_channel_cross_attention:
self._cross_attention_cls = multi_channel_attention.MultiChannelAttention
else:
self._cross_attention_cls = attention.MultiHeadAttention
def build(self, input_shape):
target_tensor_shape = tf.TensorShape(input_shape[0])
if len(target_tensor_shape) != 3:
raise ValueError("TransformerLayer expects a three-dimensional input of "
"shape [batch, sequence, width].")
hidden_size = target_tensor_shape[2]
if hidden_size % self.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (hidden_size, self.num_attention_heads))
self.attention_head_size = int(hidden_size / self.num_attention_heads)
common_kwargs = dict(
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint)
# Self attention.
self.self_attention = attention.CachedAttention(
num_heads=self.num_attention_heads,
key_size=self.attention_head_size,
dropout=self.attention_dropout_rate,
use_bias=self._use_bias,
name="self_attention",
**common_kwargs)
self.self_attention_output_dense = tf.keras.layers.experimental.EinsumDense(
"abc,cd->abd",
output_shape=(None, hidden_size),
bias_axes="d",
name="output",
**common_kwargs)
self.self_attention_dropout = tf.keras.layers.Dropout(
rate=self.dropout_rate)
self.self_attention_layer_norm = (
tf.keras.layers.LayerNormalization(
name="self_attention_layer_norm",
axis=-1,
epsilon=self._norm_epsilon))
# Encoder-decoder attention.
self.encdec_attention = self._cross_attention_cls(
num_heads=self.num_attention_heads,
key_size=self.attention_head_size,
dropout=self.attention_dropout_rate,
output_shape=hidden_size,
use_bias=self._use_bias,
name="attention/encdec",
**common_kwargs)
self.encdec_attention_dropout = tf.keras.layers.Dropout(
rate=self.dropout_rate)
self.encdec_attention_layer_norm = (
tf.keras.layers.LayerNormalization(
name="attention/encdec_output_layer_norm",
axis=-1,
epsilon=self._norm_epsilon))
# Feed-forward projection.
self.intermediate_dense = tf.keras.layers.experimental.EinsumDense(
"abc,cd->abd",
output_shape=(None, self.intermediate_size),
bias_axes="d",
name="intermediate",
**common_kwargs)
self.intermediate_activation_layer = tf.keras.layers.Activation(
self.intermediate_activation)
self.output_dense = tf.keras.layers.experimental.EinsumDense(
"abc,cd->abd",
output_shape=(None, hidden_size),
bias_axes="d",
name="output",
**common_kwargs)
self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
self.output_layer_norm = tf.keras.layers.LayerNormalization(
name="output_layer_norm", axis=-1, epsilon=self._norm_epsilon)
super(TransformerDecoderLayer, self).build(input_shape)
def get_config(self):
config = {
"num_attention_heads":
self.num_attention_heads,
"intermediate_size":
self.intermediate_size,
"intermediate_activation":
self.intermediate_activation,
"dropout_rate":
self.dropout_rate,
"attention_dropout_rate":
self.attention_dropout_rate,
"multi_channel_cross_attention":
self.multi_channel_cross_attention,
"kernel_initializer":
tf.keras.initializers.serialize(self._kernel_initializer),
"bias_initializer":
tf.keras.initializers.serialize(self._bias_initializer),
"kernel_regularizer":
tf.keras.regularizers.serialize(self._kernel_regularizer),
"bias_regularizer":
tf.keras.regularizers.serialize(self._bias_regularizer),
"activity_regularizer":
tf.keras.regularizers.serialize(self._activity_regularizer),
"kernel_constraint":
tf.keras.constraints.serialize(self._kernel_constraint),
"bias_constraint":
tf.keras.constraints.serialize(self._bias_constraint),
"use_bias":
self._use_bias,
"norm_first":
self._norm_first,
"norm_epsilon":
self._norm_epsilon
}
base_config = super(TransformerDecoderLayer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def common_layers_with_encoder(self):
"""Gets layer objects that can make a Transformer encoder block."""
return [
self.self_attention, self.self_attention_layer_norm,
self.intermediate_dense, self.output_dense, self.output_layer_norm
]
def call(self, inputs, cache=None, decode_loop_step=None):
if self.multi_channel_cross_attention:
if len(inputs) != 5:
raise ValueError(
"TransformerDecoderLayer must have 5 inputs, when it uses "
"multi_channel_cross_attention. But it got: %d" % len(inputs))
elif len(inputs) != 4:
raise ValueError(
"TransformerDecoderLayer must have 4 inputs, but it got: %d" %
len(inputs))
input_tensor, memory, attention_mask, self_attention_mask = inputs[:4]
source_tensor = input_tensor
if self._norm_first:
input_tensor = self.self_attention_layer_norm(input_tensor)
self_attention_output, cache = self.self_attention(
query=input_tensor,
value=input_tensor,
attention_mask=self_attention_mask,
cache=cache,
decode_loop_step=decode_loop_step)
self_attention_output = self.self_attention_dropout(self_attention_output)
if self._norm_first:
self_attention_output = source_tensor + self_attention_output
else:
self_attention_output = self.self_attention_layer_norm(
input_tensor + self_attention_output)
if self._norm_first:
source_self_attention_output = self_attention_output
self_attention_output = self.encdec_attention_layer_norm(
self_attention_output)
cross_attn_inputs = dict(
query=self_attention_output,
value=memory,
attention_mask=attention_mask)
if self.multi_channel_cross_attention:
# Accesses the 5-th input tensor for the doc-attention probabilities.
cross_attn_inputs["context_attention_weights"] = inputs[-1]
attention_output = self.encdec_attention(**cross_attn_inputs)
attention_output = self.encdec_attention_dropout(attention_output)
if self._norm_first:
attention_output = source_self_attention_output + attention_output
else:
attention_output = self.encdec_attention_layer_norm(
self_attention_output + attention_output)
if self._norm_first:
source_attention_output = attention_output
attention_output = self.output_layer_norm(attention_output)
intermediate_output = self.intermediate_dense(attention_output)
intermediate_output = self.intermediate_activation_layer(
intermediate_output)
layer_output = self.output_dense(intermediate_output)
layer_output = self.output_dropout(layer_output)
if self._norm_first:
layer_output = source_attention_output + layer_output
else:
layer_output = self.output_layer_norm(layer_output + attention_output)
return layer_output, cache
...@@ -262,9 +262,8 @@ class TransformerScaffold(tf.keras.layers.Layer): ...@@ -262,9 +262,8 @@ class TransformerScaffold(tf.keras.layers.Layer):
else: else:
input_tensor, attention_mask = (inputs, None) input_tensor, attention_mask = (inputs, None)
attention_inputs = [input_tensor, input_tensor] attention_output = self._attention_layer(
query=input_tensor, value=input_tensor, attention_mask=attention_mask)
attention_output = self._attention_layer(attention_inputs, attention_mask)
attention_output = self._attention_dropout(attention_output) attention_output = self._attention_dropout(attention_output)
attention_output = self._attention_layer_norm(input_tensor + attention_output = self._attention_layer_norm(input_tensor +
attention_output) attention_output)
......
...@@ -39,10 +39,10 @@ class ValidatedAttentionLayer(attention.MultiHeadAttention): ...@@ -39,10 +39,10 @@ class ValidatedAttentionLayer(attention.MultiHeadAttention):
super(ValidatedAttentionLayer, self).__init__(**kwargs) super(ValidatedAttentionLayer, self).__init__(**kwargs)
self.list = call_list self.list = call_list
def call(self, inputs, attention_mask=None): def call(self, query, value, attention_mask=None):
self.list.append(True) self.list.append(True)
return super(ValidatedAttentionLayer, self).call( return super(ValidatedAttentionLayer, self).call(
inputs, attention_mask=attention_mask) query, value, attention_mask=attention_mask)
def get_config(self): def get_config(self):
config = super(ValidatedAttentionLayer, self).get_config() config = super(ValidatedAttentionLayer, self).get_config()
......
...@@ -152,7 +152,8 @@ class TransformerLayerTest(keras_parameterized.TestCase): ...@@ -152,7 +152,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
_ = new_layer([input_data, mask_data]) _ = new_layer([input_data, mask_data])
new_layer.set_weights(test_layer.get_weights()) new_layer.set_weights(test_layer.get_weights())
new_output_tensor = new_layer([input_data, mask_data]) new_output_tensor = new_layer([input_data, mask_data])
self.assertAllClose(new_output_tensor, output_tensor[:, 0:1, :]) self.assertAllClose(
new_output_tensor, output_tensor[:, 0:1, :], atol=5e-5, rtol=0.003)
def test_layer_invocation_with_float16_dtype(self, transformer_cls): def test_layer_invocation_with_float16_dtype(self, transformer_cls):
tf.keras.mixed_precision.experimental.set_policy('mixed_float16') tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
...@@ -215,5 +216,113 @@ class TransformerLayerTest(keras_parameterized.TestCase): ...@@ -215,5 +216,113 @@ class TransformerLayerTest(keras_parameterized.TestCase):
self.assertAllEqual([1, input_length, width], output_data.shape) self.assertAllEqual([1, input_length, width], output_data.shape)
@keras_parameterized.run_all_keras_modes
class TransformerArgumentTest(keras_parameterized.TestCase):
def test_use_bias_norm_first(self):
num_attention_heads = 2
hidden_size = 16
encoder_block = transformer.Transformer(
num_attention_heads=num_attention_heads,
intermediate_size=32,
intermediate_activation='relu',
dropout_rate=0.1,
attention_dropout_rate=0.1,
use_bias=False,
norm_first=True,
norm_epsilon=1e-6)
# Forward path.
dummy_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
dummy_mask = tf.zeros([2, 4, 4], dtype=tf.float32)
inputs = [dummy_tensor, dummy_mask]
output = encoder_block(inputs)
self.assertEqual(output.shape, (2, 4, hidden_size))
def test_get_config(self):
num_attention_heads = 2
encoder_block = transformer.Transformer(
num_attention_heads=num_attention_heads,
intermediate_size=32,
intermediate_activation='relu',
dropout_rate=0.1,
attention_dropout_rate=0.1,
use_bias=False,
norm_first=True,
norm_epsilon=1e-6)
encoder_block_config = encoder_block.get_config()
new_encoder_block = transformer.Transformer.from_config(
encoder_block_config)
self.assertEqual(encoder_block_config, new_encoder_block.get_config())
def _create_cache(batch_size, init_decode_length, num_heads, head_size):
return {
'key':
tf.zeros([batch_size, init_decode_length, num_heads, head_size],
dtype=tf.float32),
'value':
tf.zeros([batch_size, init_decode_length, num_heads, head_size],
dtype=tf.float32)
}
@keras_parameterized.run_all_keras_modes
class TransformerDecoderLayerTest(keras_parameterized.TestCase):
def test_decoder_block_with_cache(self):
num_attention_heads = 2
hidden_size = 16
decoder_block = transformer.TransformerDecoderLayer(
num_attention_heads=num_attention_heads,
intermediate_size=32,
intermediate_activation='relu',
dropout_rate=0.1,
attention_dropout_rate=0.1)
# Forward path.
dummy_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
dummy_mask = tf.zeros([2, 4, 4], dtype=tf.float32)
inputs = [dummy_tensor, dummy_tensor, dummy_mask, dummy_mask]
cache = _create_cache(2, 0, num_attention_heads,
hidden_size // num_attention_heads)
output, cache = decoder_block(inputs, cache)
self.assertEqual(output.shape, (2, 4, hidden_size))
self.assertEqual(cache['value'].shape, (2, 4, 2, 8))
def test_use_bias_norm_first(self):
num_attention_heads = 2
hidden_size = 16
decoder_block = transformer.TransformerDecoderLayer(
num_attention_heads=num_attention_heads,
intermediate_size=32,
intermediate_activation='relu',
dropout_rate=0.1,
attention_dropout_rate=0.1,
use_bias=False,
norm_first=True,
norm_epsilon=1e-6)
# Forward path.
dummy_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
dummy_mask = tf.zeros([2, 4, 4], dtype=tf.float32)
inputs = [dummy_tensor, dummy_tensor, dummy_mask, dummy_mask]
output, _ = decoder_block(inputs)
self.assertEqual(output.shape, (2, 4, hidden_size))
def test_get_config(self):
num_attention_heads = 2
decoder_block = transformer.TransformerDecoderLayer(
num_attention_heads=num_attention_heads,
intermediate_size=32,
intermediate_activation='relu',
dropout_rate=0.1,
attention_dropout_rate=0.1,
use_bias=False,
norm_first=True,
norm_epsilon=1e-6)
decoder_block_config = decoder_block.get_config()
new_decoder_block = transformer.TransformerDecoderLayer.from_config(
decoder_block_config)
self.assertEqual(decoder_block_config, new_decoder_block.get_config())
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -4,6 +4,3 @@ Losses contains common loss computation used in NLP tasks. ...@@ -4,6 +4,3 @@ Losses contains common loss computation used in NLP tasks.
* `weighted_sparse_categorical_crossentropy_loss` computes per-batch sparse * `weighted_sparse_categorical_crossentropy_loss` computes per-batch sparse
categorical crossentropy loss. categorical crossentropy loss.
* `weighted_sparse_categorical_crossentropy_per_example_loss` computes
per-example sparse categorical crossentropy loss.
...@@ -14,4 +14,3 @@ ...@@ -14,4 +14,3 @@
# ============================================================================== # ==============================================================================
"""Activations package definition. Subject to change.""" """Activations package definition. Subject to change."""
from official.nlp.modeling.losses.weighted_sparse_categorical_crossentropy import loss as weighted_sparse_categorical_crossentropy_loss from official.nlp.modeling.losses.weighted_sparse_categorical_crossentropy import loss as weighted_sparse_categorical_crossentropy_loss
from official.nlp.modeling.losses.weighted_sparse_categorical_crossentropy import per_example_loss as weighted_sparse_categorical_crossentropy_per_example_loss
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Sparse categorical cross-entropy losses.""" """Weighted sparse categorical cross-entropy losses."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -43,37 +43,7 @@ def _validate_rank(labels, predictions, weights): ...@@ -43,37 +43,7 @@ def _validate_rank(labels, predictions, weights):
"predictions.shape was %s.") % (labels.shape, predictions.shape)) "predictions.shape was %s.") % (labels.shape, predictions.shape))
def per_example_loss(labels, predictions, weights=None): def loss(labels, predictions, weights=None, from_logits=False):
"""Calculate a per-example sparse categorical crossentropy loss.
This loss function assumes that the predictions are post-softmax.
Args:
labels: The labels to evaluate against. Should be a set of integer indices
ranging from 0 to (vocab_size-1).
predictions: The network predictions. Should have softmax already applied.
weights: An optional weight array of the same shape as the 'labels' array.
If None, all examples will be used.
Returns:
A tensor of shape predictions.shape[:-1] containing the per-example
loss.
"""
# When using these functions with the Keras core API, we will need to squeeze
# the labels tensor - Keras adds a spurious inner dimension.
labels, predictions = _adjust_labels(labels, predictions)
_validate_rank(labels, predictions, weights)
labels_one_hot = tf.one_hot(labels, predictions.shape[-1])
labels_one_hot = tf.cast(labels_one_hot, predictions.dtype)
per_example_loss_data = -tf.reduce_sum(
predictions * labels_one_hot, axis=[-1])
if weights is not None:
weights = tf.cast(weights, per_example_loss_data.dtype)
per_example_loss_data = weights * per_example_loss_data
return per_example_loss_data
def loss(labels, predictions, weights=None):
"""Calculate a per-batch sparse categorical crossentropy loss. """Calculate a per-batch sparse categorical crossentropy loss.
This loss function assumes that the predictions are post-softmax. This loss function assumes that the predictions are post-softmax.
...@@ -83,6 +53,7 @@ def loss(labels, predictions, weights=None): ...@@ -83,6 +53,7 @@ def loss(labels, predictions, weights=None):
predictions: The network predictions. Should have softmax already applied. predictions: The network predictions. Should have softmax already applied.
weights: An optional weight array of the same shape as the 'labels' array. weights: An optional weight array of the same shape as the 'labels' array.
If None, all examples will be used. If None, all examples will be used.
from_logits: Whether the input predictions are logits.
Returns: Returns:
A loss scalar. A loss scalar.
...@@ -95,12 +66,11 @@ def loss(labels, predictions, weights=None): ...@@ -95,12 +66,11 @@ def loss(labels, predictions, weights=None):
labels, predictions = _adjust_labels(labels, predictions) labels, predictions = _adjust_labels(labels, predictions)
_validate_rank(labels, predictions, weights) _validate_rank(labels, predictions, weights)
per_example_loss_data = per_example_loss(labels, predictions, weights) example_losses = tf.keras.losses.sparse_categorical_crossentropy(
labels, predictions, from_logits=from_logits)
if weights is None: if weights is None:
return tf.reduce_mean(per_example_loss_data) return tf.reduce_mean(example_losses)
else: weights = tf.cast(weights, predictions.dtype)
numerator = tf.reduce_sum(per_example_loss_data) return tf.math.divide_no_nan(
weights = tf.cast(weights, predictions.dtype) tf.reduce_sum(example_losses * weights), tf.reduce_sum(weights))
denominator = tf.reduce_sum(weights) + 1e-5
return numerator / denominator
...@@ -53,8 +53,7 @@ class ClassificationLossTest(keras_parameterized.TestCase): ...@@ -53,8 +53,7 @@ class ClassificationLossTest(keras_parameterized.TestCase):
# Create a maskedLM from the transformer stack. # Create a maskedLM from the transformer stack.
test_layer = layers.MaskedLM( test_layer = layers.MaskedLM(
embedding_table=xformer_stack.get_embedding_table(), embedding_table=xformer_stack.get_embedding_table(), output=output)
output=output)
# Create a model from the masked LM layer. # Create a model from the masked LM layer.
lm_input_tensor = tf.keras.Input(shape=(sequence_length, hidden_size)) lm_input_tensor = tf.keras.Input(shape=(sequence_length, hidden_size))
...@@ -63,123 +62,6 @@ class ClassificationLossTest(keras_parameterized.TestCase): ...@@ -63,123 +62,6 @@ class ClassificationLossTest(keras_parameterized.TestCase):
output = test_layer(lm_input_tensor, masked_positions=masked_lm_positions) output = test_layer(lm_input_tensor, masked_positions=masked_lm_positions)
return tf.keras.Model([lm_input_tensor, masked_lm_positions], output) return tf.keras.Model([lm_input_tensor, masked_lm_positions], output)
def create_classification_model(self, input_width, num_classes):
test_object = networks.Classification(
input_width=input_width, num_classes=num_classes)
# Create a 2-dimensional input (the first dimension is implicit).
pooled_data = tf.keras.Input(shape=(input_width,), dtype=tf.float32)
output = test_object(pooled_data)
return tf.keras.Model(pooled_data, output)
def test_per_example_loss_3d_input(self):
"""Test per-example loss with a 3-dimensional input, from a masked LM."""
vocab_size = 100
sequence_length = 32
hidden_size = 64
num_predictions = 21
model = self.create_lm_model(
vocab_size=vocab_size,
sequence_length=sequence_length,
hidden_size=hidden_size,
num_predictions=num_predictions)
# Get the output of the masked LM.
batch_size = 3
lm_input_data = 10 * np.random.random_sample(
(batch_size, sequence_length, hidden_size))
masked_position_data = np.random.randint(
2, size=(batch_size, num_predictions))
output_data = model.predict([lm_input_data, masked_position_data])
# Calculate per-example loss.
labels = np.random.randint(vocab_size, size=(batch_size, num_predictions))
per_example_loss_data = weighted_sparse_categorical_crossentropy.per_example_loss(
predictions=output_data, labels=labels)
# Per-example loss data should have one value per prediction, and those
# values shouldn't be zero in this case (as we're using random data).
expected_shape = [batch_size, num_predictions]
self.assertEqual(expected_shape, per_example_loss_data.shape.as_list())
self.assertNotAllClose(
tf.zeros_like(per_example_loss_data), per_example_loss_data)
def test_per_example_loss_2d_input(self):
"""Test per-example loss with a 2-d input, from a classifier."""
input_width = 512
num_classes = 10
model = self.create_classification_model(input_width, num_classes)
# Invoke the network as part of a Model.
batch_size = 3
input_data = 10 * np.random.random_sample((batch_size, input_width))
output_data = model.predict(input_data)
# Calculate per example loss.
labels = np.random.randint(num_classes, size=(batch_size))
per_example_loss_data = weighted_sparse_categorical_crossentropy.per_example_loss(
predictions=output_data, labels=labels)
# Per-example loss data should have one value per batch item, and those
# values shouldn't be zero in this case (as we're using random data).
self.assertEqual([batch_size], per_example_loss_data.shape.as_list())
self.assertNotAllClose(
tf.zeros_like(per_example_loss_data), per_example_loss_data)
def test_per_example_loss_weights_3d_input(self):
"""Test weighted per-example loss with a 3-d input, from a masked LM."""
vocab_size = 100
sequence_length = 32
hidden_size = 64
num_predictions = 21
model = self.create_lm_model(
vocab_size=vocab_size,
sequence_length=sequence_length,
hidden_size=hidden_size,
num_predictions=num_predictions)
# Get the output of the masked LM.
batch_size = 3
lm_input_data = 10 * np.random.random_sample(
(batch_size, sequence_length, hidden_size))
masked_position_data = np.random.randint(
2, size=(batch_size, num_predictions))
output_data = model.predict([lm_input_data, masked_position_data])
# Calculate per-example loss with weights.
labels = np.random.randint(vocab_size, size=(batch_size, num_predictions))
weights = np.random.randint(2, size=(batch_size, num_predictions))
per_example_loss_data = weighted_sparse_categorical_crossentropy.per_example_loss(
predictions=output_data, labels=labels, weights=weights)
# Weighted per-example loss data should be equivalent to multiplying the
# loss tensor by the weights tensor.
expected_weighted_loss = per_example_loss_data * weights
self.assertAllClose(expected_weighted_loss, per_example_loss_data)
def test_per_example_loss_weights_2d_input(self):
"""Test weighted per-example loss with a 2-d input, from a classifier."""
input_width = 512
num_classes = 10
model = self.create_classification_model(input_width, num_classes)
# Invoke the network as part of a Model.
batch_size = 3
input_data = 10 * np.random.random_sample((batch_size, input_width))
output_data = model.predict(input_data)
# Calculate per-example loss with weights.
labels = np.random.randint(num_classes, size=(batch_size))
weights = np.random.randint(2, size=(batch_size))
per_example_loss_data = weighted_sparse_categorical_crossentropy.per_example_loss(
predictions=output_data, labels=labels, weights=weights)
# Weighted per-example loss data should be equivalent to multiplying the
# loss tensor by the weights tensor.
expected_weighted_loss = per_example_loss_data * weights
self.assertAllClose(expected_weighted_loss, per_example_loss_data)
def test_loss_3d_input(self): def test_loss_3d_input(self):
"""Test overall loss with a 3-dimensional input, from a masked LM.""" """Test overall loss with a 3-dimensional input, from a masked LM."""
vocab_size = 100 vocab_size = 100
...@@ -213,26 +95,6 @@ class ClassificationLossTest(keras_parameterized.TestCase): ...@@ -213,26 +95,6 @@ class ClassificationLossTest(keras_parameterized.TestCase):
self.assertNotAllClose( self.assertNotAllClose(
tf.zeros_like(per_example_loss_data), per_example_loss_data) tf.zeros_like(per_example_loss_data), per_example_loss_data)
def test_loss_2d_input(self):
"""Test overall loss with a 2-d input, from a classifier."""
input_width = 512
num_classes = 10
model = self.create_classification_model(input_width, num_classes)
# Invoke the network as part of a Model.
batch_size = 3
input_data = 10 * np.random.random_sample((batch_size, input_width))
output_data = model.predict(input_data)
# Calculate per example loss.
labels = np.random.randint(num_classes, size=(batch_size))
loss_data = weighted_sparse_categorical_crossentropy.loss(
predictions=output_data, labels=labels)
# Loss data should have one value only, and that value shouldn't be zero in
# this case (as we're using random data).
self.assertNotAllClose(0, loss_data)
def test_loss_weights_3d_input(self): def test_loss_weights_3d_input(self):
"""Test masked loss with a 3-dimensional input, from a masked LM.""" """Test masked loss with a 3-dimensional input, from a masked LM."""
vocab_size = 100 vocab_size = 100
...@@ -262,26 +124,6 @@ class ClassificationLossTest(keras_parameterized.TestCase): ...@@ -262,26 +124,6 @@ class ClassificationLossTest(keras_parameterized.TestCase):
# Because the tensor is fully masked, the loss should be 0. # Because the tensor is fully masked, the loss should be 0.
self.assertAllClose(0, weighted_loss_data) self.assertAllClose(0, weighted_loss_data)
def test_loss_weights_2d_input(self):
"""Test masked loss with a 2-d input, from a classifier."""
input_width = 512
num_classes = 10
model = self.create_classification_model(input_width, num_classes)
# Invoke the network as part of a Model.
batch_size = 3
input_data = 10 * np.random.random_sample((batch_size, input_width))
output_data = model.predict(input_data)
# Calculate a fully masked weight tensor. This should give a loss of zero.
labels = np.random.randint(num_classes, size=(batch_size))
null_weights = np.zeros((batch_size))
weighted_loss_data = weighted_sparse_categorical_crossentropy.loss(
predictions=output_data, labels=labels, weights=null_weights)
# Because the tensor is fully masked, the loss should be 0.
self.assertAllClose(0, weighted_loss_data)
def test_mismatched_predictions_and_labels_ranks_squeezes(self): def test_mismatched_predictions_and_labels_ranks_squeezes(self):
"""Test that the loss asserts when rank(predictions)-1 != rank(labels).""" """Test that the loss asserts when rank(predictions)-1 != rank(labels)."""
batch_size = 3 batch_size = 3
...@@ -289,7 +131,7 @@ class ClassificationLossTest(keras_parameterized.TestCase): ...@@ -289,7 +131,7 @@ class ClassificationLossTest(keras_parameterized.TestCase):
labels = np.random.randint(10, size=(batch_size, 1)) labels = np.random.randint(10, size=(batch_size, 1))
# All that this test tests is that the squeeze is successful. # All that this test tests is that the squeeze is successful.
_ = weighted_sparse_categorical_crossentropy.per_example_loss( _ = weighted_sparse_categorical_crossentropy.loss(
predictions=output_data, labels=labels) predictions=output_data, labels=labels)
def test_mismatched_weights_and_labels_ranks_fail(self): def test_mismatched_weights_and_labels_ranks_fail(self):
...@@ -299,9 +141,6 @@ class ClassificationLossTest(keras_parameterized.TestCase): ...@@ -299,9 +141,6 @@ class ClassificationLossTest(keras_parameterized.TestCase):
labels = np.random.randint(10, size=(batch_size, 10)) labels = np.random.randint(10, size=(batch_size, 10))
weights = np.random.randint(2, size=(batch_size)) weights = np.random.randint(2, size=(batch_size))
with self.assertRaisesRegex(RuntimeError, ".*of the same rank.*"):
_ = weighted_sparse_categorical_crossentropy.per_example_loss(
predictions=output_data, labels=labels, weights=weights)
with self.assertRaisesRegex(RuntimeError, ".*of the same rank.*"): with self.assertRaisesRegex(RuntimeError, ".*of the same rank.*"):
_ = weighted_sparse_categorical_crossentropy.loss( _ = weighted_sparse_categorical_crossentropy.loss(
predictions=output_data, labels=labels, weights=weights) predictions=output_data, labels=labels, weights=weights)
...@@ -317,8 +156,6 @@ class ClassificationLossTest(keras_parameterized.TestCase): ...@@ -317,8 +156,6 @@ class ClassificationLossTest(keras_parameterized.TestCase):
# We're not trying to validate numerical correctness, just ensure that # We're not trying to validate numerical correctness, just ensure that
# we can in fact pass tensors to these functions without causing runtime # we can in fact pass tensors to these functions without causing runtime
# errors from the shape checking code. # errors from the shape checking code.
_ = weighted_sparse_categorical_crossentropy.per_example_loss(
predictions=output_data, labels=labels, weights=weights)
_ = weighted_sparse_categorical_crossentropy.loss( _ = weighted_sparse_categorical_crossentropy.loss(
predictions=output_data, labels=labels, weights=weights) predictions=output_data, labels=labels, weights=weights)
...@@ -338,20 +175,15 @@ class ClassificationLossTest(keras_parameterized.TestCase): ...@@ -338,20 +175,15 @@ class ClassificationLossTest(keras_parameterized.TestCase):
[-2.7760355, -1.8219438, -3.0924666, -1.0779881, -0.9407509]]]) [-2.7760355, -1.8219438, -3.0924666, -1.0779881, -0.9407509]]])
labels = np.array([[4, 0], [2, 2], [2, 1]]) labels = np.array([[4, 0], [2, 2], [2, 1]])
# Validate that per_example loss calculations are the same.
per_example_loss_data = weighted_sparse_categorical_crossentropy.per_example_loss(
predictions=output_data, labels=labels)
expected_per_example_loss_data = [[1.2923571, 2.7117882],
[2.287932, 2.287932],
[3.0924666, 1.8219438]]
self.assertAllClose(expected_per_example_loss_data, per_example_loss_data)
# Validate that overall loss calculations are the same. # Validate that overall loss calculations are the same.
weights = np.array([[1, 0], [0, 0], [0, 0]]) weights = np.array([[1, 0], [0, 0], [0, 0]])
loss_data = weighted_sparse_categorical_crossentropy.loss( loss_data = weighted_sparse_categorical_crossentropy.loss(
predictions=output_data, labels=labels, weights=weights) predictions=output_data,
labels=labels,
weights=weights,
from_logits=True)
expected_loss_data = 1.2923441 expected_loss_data = 1.2923441
self.assertAllClose(expected_loss_data, loss_data) self.assertAllClose(expected_loss_data, loss_data, rtol=1e-3)
def test_legacy_classification_loss_compatibility(self): def test_legacy_classification_loss_compatibility(self):
"""Test to validate computational correctness during refactors.""" """Test to validate computational correctness during refactors."""
...@@ -362,19 +194,15 @@ class ClassificationLossTest(keras_parameterized.TestCase): ...@@ -362,19 +194,15 @@ class ClassificationLossTest(keras_parameterized.TestCase):
[-1.6975292e-03, -6.4009643e+00, -1.0226612e+01]]) [-1.6975292e-03, -6.4009643e+00, -1.0226612e+01]])
labels = np.array([2, 1]) labels = np.array([2, 1])
# Validate that per_example loss calculations are the same.
per_example_loss_data = weighted_sparse_categorical_crossentropy.per_example_loss(
predictions=output_data, labels=labels)
expected_per_example_loss_data = [6.4434357, 6.4009643]
self.assertAllClose(expected_per_example_loss_data, per_example_loss_data)
# Validate that overall loss calculations are the same. # Validate that overall loss calculations are the same.
weights = None weights = None
loss_data = weighted_sparse_categorical_crossentropy.loss( loss_data = weighted_sparse_categorical_crossentropy.loss(
predictions=output_data, labels=labels, weights=weights) predictions=output_data,
labels=labels,
weights=weights,
from_logits=True)
expected_loss_data = 6.4222 expected_loss_data = 6.4222
self.assertAllClose(expected_loss_data, loss_data) self.assertAllClose(expected_loss_data, loss_data, rtol=1e-3)
if __name__ == "__main__": if __name__ == "__main__":
tf.test.main() tf.test.main()
...@@ -10,8 +10,8 @@ model containing a single classification head using the Classification network. ...@@ -10,8 +10,8 @@ model containing a single classification head using the Classification network.
It can be used as a regression model as well. It can be used as a regression model as well.
* [`BertTokenClassifier`](bert_token_classifier.py) implements a simple token * [`BertTokenClassifier`](bert_token_classifier.py) implements a simple token
classification model containing a single classification head using the classification model containing a single classification head over the sequence
TokenClassification network. output embeddings.
* [`BertSpanLabeler`](bert_span_labeler.py) implementats a simple single-span * [`BertSpanLabeler`](bert_span_labeler.py) implementats a simple single-span
start-end predictor (that is, a model that predicts two values: a start token start-end predictor (that is, a model that predicts two values: a start token
......
...@@ -17,3 +17,4 @@ from official.nlp.modeling.models.bert_classifier import BertClassifier ...@@ -17,3 +17,4 @@ from official.nlp.modeling.models.bert_classifier import BertClassifier
from official.nlp.modeling.models.bert_pretrainer import BertPretrainer from official.nlp.modeling.models.bert_pretrainer import BertPretrainer
from official.nlp.modeling.models.bert_span_labeler import BertSpanLabeler from official.nlp.modeling.models.bert_span_labeler import BertSpanLabeler
from official.nlp.modeling.models.bert_token_classifier import BertTokenClassifier from official.nlp.modeling.models.bert_token_classifier import BertTokenClassifier
from official.nlp.modeling.models.electra_pretrainer import ElectraPretrainer
...@@ -12,15 +12,12 @@ ...@@ -12,15 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Trainer network for BERT-style models.""" """BERT cls-token classifier."""
# pylint: disable=g-classes-have-attributes # pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import tensorflow as tf import tensorflow as tf
from official.nlp.modeling import layers
from official.nlp.modeling import networks from official.nlp.modeling import networks
...@@ -36,6 +33,9 @@ class BertClassifier(tf.keras.Model): ...@@ -36,6 +33,9 @@ class BertClassifier(tf.keras.Model):
instantiates a classification network based on the passed `num_classes` instantiates a classification network based on the passed `num_classes`
argument. If `num_classes` is set to 1, a regression network is instantiated. argument. If `num_classes` is set to 1, a regression network is instantiated.
*Note* that the model is constructed by
[Keras Functional API](https://keras.io/guides/functional_api/).
Arguments: Arguments:
network: A transformer network. This network should output a sequence output network: A transformer network. This network should output a sequence output
and a classification output. Furthermore, it should expose its embedding and a classification output. Furthermore, it should expose its embedding
...@@ -43,23 +43,25 @@ class BertClassifier(tf.keras.Model): ...@@ -43,23 +43,25 @@ class BertClassifier(tf.keras.Model):
num_classes: Number of classes to predict from the classification network. num_classes: Number of classes to predict from the classification network.
initializer: The initializer (if any) to use in the classification networks. initializer: The initializer (if any) to use in the classification networks.
Defaults to a Glorot uniform initializer. Defaults to a Glorot uniform initializer.
output: The output style for this network. Can be either 'logits' or dropout_rate: The dropout probability of the cls head.
'predictions'. use_encoder_pooler: Whether to use the pooler layer pre-defined inside
the encoder.
""" """
def __init__(self, def __init__(self,
network, network,
num_classes, num_classes,
initializer='glorot_uniform', initializer='glorot_uniform',
output='logits',
dropout_rate=0.1, dropout_rate=0.1,
use_encoder_pooler=True,
**kwargs): **kwargs):
self._self_setattr_tracking = False self._self_setattr_tracking = False
self._network = network
self._config = { self._config = {
'network': network, 'network': network,
'num_classes': num_classes, 'num_classes': num_classes,
'initializer': initializer, 'initializer': initializer,
'output': output, 'use_encoder_pooler': use_encoder_pooler,
} }
# We want to use the inputs of the passed network as the inputs to this # We want to use the inputs of the passed network as the inputs to this
...@@ -67,22 +69,36 @@ class BertClassifier(tf.keras.Model): ...@@ -67,22 +69,36 @@ class BertClassifier(tf.keras.Model):
# when we construct the Model object at the end of init. # when we construct the Model object at the end of init.
inputs = network.inputs inputs = network.inputs
# Because we have a copy of inputs to create this Model object, we can if use_encoder_pooler:
# invoke the Network object with its own input tensors to start the Model. # Because we have a copy of inputs to create this Model object, we can
_, cls_output = network(inputs) # invoke the Network object with its own input tensors to start the Model.
cls_output = tf.keras.layers.Dropout(rate=dropout_rate)(cls_output) _, cls_output = network(inputs)
cls_output = tf.keras.layers.Dropout(rate=dropout_rate)(cls_output)
self.classifier = networks.Classification( self.classifier = networks.Classification(
input_width=cls_output.shape[-1], input_width=cls_output.shape[-1],
num_classes=num_classes, num_classes=num_classes,
initializer=initializer, initializer=initializer,
output=output, output='logits',
name='classification') name='sentence_prediction')
predictions = self.classifier(cls_output) predictions = self.classifier(cls_output)
else:
sequence_output, _ = network(inputs)
self.classifier = layers.ClassificationHead(
inner_dim=sequence_output.shape[-1],
num_classes=num_classes,
initializer=initializer,
dropout_rate=dropout_rate,
name='sentence_prediction')
predictions = self.classifier(sequence_output)
super(BertClassifier, self).__init__( super(BertClassifier, self).__init__(
inputs=inputs, outputs=predictions, **kwargs) inputs=inputs, outputs=predictions, **kwargs)
@property
def checkpoint_items(self):
return dict(encoder=self._network)
def get_config(self): def get_config(self):
return self._config return self._config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment