Unverified Commit 0cceabfc authored by Yiming Shi's avatar Yiming Shi Committed by GitHub
Browse files

Merge branch 'master' into move_to_keraslayers_fasterrcnn_fpn_keras_feature_extractor

parents 17821c0d 39ee0ac9
......@@ -34,7 +34,6 @@ class MaskedLMTest(keras_parameterized.TestCase):
def create_layer(self,
vocab_size,
sequence_length,
hidden_size,
output='predictions',
xformer_stack=None):
......@@ -44,7 +43,6 @@ class MaskedLMTest(keras_parameterized.TestCase):
xformer_stack = transformer_encoder.TransformerEncoder(
vocab_size=vocab_size,
num_layers=1,
sequence_length=sequence_length,
hidden_size=hidden_size,
num_attention_heads=4,
)
......@@ -62,7 +60,6 @@ class MaskedLMTest(keras_parameterized.TestCase):
num_predictions = 21
test_layer = self.create_layer(
vocab_size=vocab_size,
sequence_length=sequence_length,
hidden_size=hidden_size)
# Make sure that the output tensor of the masked LM is the right shape.
......@@ -81,19 +78,16 @@ class MaskedLMTest(keras_parameterized.TestCase):
xformer_stack = transformer_encoder.TransformerEncoder(
vocab_size=vocab_size,
num_layers=1,
sequence_length=sequence_length,
hidden_size=hidden_size,
num_attention_heads=4,
)
test_layer = self.create_layer(
vocab_size=vocab_size,
sequence_length=sequence_length,
hidden_size=hidden_size,
xformer_stack=xformer_stack,
output='predictions')
logit_layer = self.create_layer(
vocab_size=vocab_size,
sequence_length=sequence_length,
hidden_size=hidden_size,
xformer_stack=xformer_stack,
output='logits')
......@@ -134,7 +128,6 @@ class MaskedLMTest(keras_parameterized.TestCase):
num_predictions = 21
test_layer = self.create_layer(
vocab_size=vocab_size,
sequence_length=sequence_length,
hidden_size=hidden_size)
# Create a model from the masked LM layer.
......@@ -155,7 +148,7 @@ class MaskedLMTest(keras_parameterized.TestCase):
def test_unknown_output_type_fails(self):
with self.assertRaisesRegex(ValueError, 'Unknown `output` value "bad".*'):
_ = self.create_layer(
vocab_size=8, sequence_length=8, hidden_size=8, output='bad')
vocab_size=8, hidden_size=8, output='bad')
if __name__ == '__main__':
......
......@@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Multi-channel decoder."""
"""Multi-channel Attention."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
......@@ -24,11 +25,24 @@ import math
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling import layers
from official.nlp.modeling.layers import attention
from official.nlp.modeling.layers import masked_softmax
class DocAttention(tf.keras.layers.Layer):
"""Documents Attention layer."""
class VotingAttention(tf.keras.layers.Layer):
"""Voting Attention layer.
Arguments:
num_heads: the number of attention heads.
head_size: per-head hidden size.
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels.
bias_regularizer: Regularizer for dense layer biases.
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer kernels.
"""
def __init__(self,
num_heads,
......@@ -41,7 +55,7 @@ class DocAttention(tf.keras.layers.Layer):
kernel_constraint=None,
bias_constraint=None,
**kwargs):
super(DocAttention, self).__init__(**kwargs)
super(VotingAttention, self).__init__(**kwargs)
self._num_heads = num_heads
self._head_size = head_size
self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
......@@ -52,29 +66,27 @@ class DocAttention(tf.keras.layers.Layer):
self._bias_constraint = tf.keras.constraints.get(bias_constraint)
def build(self, unused_input_shapes):
self._query_dense = layers.DenseEinsum(
output_shape=(self._num_heads, self._head_size),
common_kwargs = dict(
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
dtype=self.dtype,
name="encdocatt_query")
self._key_dense = layers.DenseEinsum(
output_shape=(self._num_heads, self._head_size),
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
dtype=self.dtype,
name="encdocatt_key")
super(DocAttention, self).build(unused_input_shapes)
bias_constraint=self._bias_constraint)
self._query_dense = tf.keras.layers.experimental.EinsumDense(
"BAE,ENH->BANH",
output_shape=(None, self._num_heads, self._head_size),
bias_axes="NH",
name="query",
**common_kwargs)
self._key_dense = tf.keras.layers.experimental.EinsumDense(
"BAE,ENH->BANH",
output_shape=(None, self._num_heads, self._head_size),
bias_axes="NH",
name="key",
**common_kwargs)
super(VotingAttention, self).build(unused_input_shapes)
def call(self, encoder_outputs, doc_attention_mask):
num_docs = tf_utils.get_shape_list(encoder_outputs, expected_rank=[4])[1]
......@@ -95,33 +107,55 @@ class DocAttention(tf.keras.layers.Layer):
return tf.nn.softmax(doc_attention_probs + infadder)
class MultiChannelAttention(layers.MultiHeadAttention):
"""Multi-channel Attention layer."""
def build(self, input_shape):
super(MultiChannelAttention, self).build(input_shape)
self._masked_softmax = layers.MaskedSoftmax(mask_expansion_axes=[2])
def call(self, inputs, attention_mask=None):
from_tensor = inputs[0]
to_tensor = inputs[1]
doc_attention_probs = inputs[2]
class MultiChannelAttention(attention.MultiHeadAttention):
"""Multi-channel Attention layer.
Introduced in, [Generating Representative Headlines for News Stories
](https://arxiv.org/abs/2001.09386). Expects multiple cross-attention
target sequences.
Call args:
query: Query `Tensor` of shape `[B, T, dim]`.
value: Value `Tensor` of shape `[B, A, S, dim]`, where A denotes the
context_attention_weights: Context weights of shape `[B, N, T, A]`, where N
is the number of attention heads. Combines multi-channel sources
context tensors according to the distribution among channels.
key: Optional key `Tensor` of shape `[B, A, S, dim]`. If not given, will use
`value` for both `key` and `value`, which is the most common case.
attention_mask: a boolean mask of shape `[B, T, S]`, that prevents attention
to certain positions.
"""
def build_attention(self, rank):
super(MultiChannelAttention, self).build_attention(rank)
self._masked_softmax = masked_softmax.MaskedSoftmax(mask_expansion_axes=[2])
def call(self,
query,
value,
key=None,
context_attention_weights=None,
attention_mask=None):
if not self._built_from_signature:
self._build_from_signature(query, value, key=key)
if key is None:
key = value
# Scalar dimensions referenced here:
# B = batch size (number of stories)
# A = num_docs (number of docs)
# F = `from_tensor` sequence length
# T = `to_tensor` sequence length
# F = target sequence length
# T = source sequence length
# N = `num_attention_heads`
# H = `size_per_head`
# `query_tensor` = [B, F, N ,H]
query_tensor = self._query_dense(from_tensor)
query_tensor = self._query_dense(query)
# `key_tensor` = [B, A, T, N, H]
key_tensor = self._key_dense(to_tensor)
key_tensor = self._key_dense(key)
# `value_tensor` = [B, A, T, N, H]
value_tensor = self._value_dense(to_tensor)
value_tensor = self._value_dense(value)
# Take the dot product between "query" and "key" to get the raw
# attention scores.
......@@ -140,7 +174,7 @@ class MultiChannelAttention(layers.MultiHeadAttention):
# `context_layer` = [B, F, N, H]
context_layer = tf.einsum("BANFT,BATNH->BAFNH", attention_probs,
value_tensor)
attention_output = tf.einsum("BNFA,BAFNH->BFNH", doc_attention_probs,
attention_output = tf.einsum("BNFA,BAFNH->BFNH", context_attention_weights,
context_layer)
attention_output = self._output_dense(attention_output)
return attention_output
......@@ -22,14 +22,15 @@ from __future__ import print_function
import numpy as np
import tensorflow as tf
from official.nlp.nhnet import multi_channel_attention
from official.nlp.modeling.layers import multi_channel_attention
class MultiChannelAttentionTest(tf.test.TestCase):
def test_doc_attention(self):
num_heads = 2
doc_attention = multi_channel_attention.DocAttention(num_heads, head_size=8)
doc_attention = multi_channel_attention.VotingAttention(
num_heads, head_size=8)
num_docs = 3
inputs = np.zeros((2, num_docs, 10, 16), dtype=np.float32)
doc_mask = np.zeros((2, num_docs), dtype=np.float32)
......@@ -47,7 +48,11 @@ class MultiChannelAttentionTest(tf.test.TestCase):
mask_data = np.random.randint(2, size=(3, num_docs, 4, 2))
doc_probs = np.random.randint(
2, size=(3, num_heads, 4, num_docs)).astype(float)
outputs = attention_layer([from_data, to_data, doc_probs], mask_data)
outputs = attention_layer(
query=from_data,
value=to_data,
context_attention_weights=doc_probs,
attention_mask=mask_data)
self.assertEqual(outputs.shape, (3, 4, 8))
......
......@@ -38,6 +38,9 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
lookup. Defaults to False (that is, using tf.gather). Setting this option
to True may improve performance, especially on small vocabulary sizes, but
will generally require more memory.
use_scale: Whether to scale the output embeddings. Defaults to False (that
is, not to scale). Setting this option to True will let values in output
embeddings multiplied by self._embedding_width ** 0.5.
"""
def __init__(self,
......@@ -45,6 +48,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
embedding_width,
initializer="glorot_uniform",
use_one_hot=False,
use_scale=False,
**kwargs):
super(OnDeviceEmbedding, self).__init__(**kwargs)
......@@ -52,6 +56,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
self._embedding_width = embedding_width
self._initializer = initializer
self._use_one_hot = use_one_hot
self._use_scale = use_scale
def get_config(self):
config = {
......@@ -59,6 +64,7 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
"embedding_width": self._embedding_width,
"initializer": self._initializer,
"use_one_hot": self._use_one_hot,
"use_scale": self._use_scale,
}
base_config = super(OnDeviceEmbedding, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
......@@ -85,4 +91,6 @@ class OnDeviceEmbedding(tf.keras.layers.Layer):
# Work around b/142213824: prefer concat to shape over a Python list.
tf.concat([tf.shape(inputs), [self._embedding_width]], axis=0))
embeddings.set_shape(inputs.shape.as_list() + [self._embedding_width])
if self._use_scale:
embeddings *= self._embedding_width ** 0.5
return embeddings
......@@ -193,6 +193,26 @@ class OnDeviceEmbeddingTest(keras_parameterized.TestCase):
output = model.predict(input_data)
self.assertEqual(tf.float16, output.dtype)
def test_use_scale_layer_invocation(self):
vocab_size = 31
embedding_width = 27
test_layer = on_device_embedding.OnDeviceEmbedding(
vocab_size=vocab_size, embedding_width=embedding_width, use_scale=True)
# Create a 2-dimensional input (the first dimension is implicit).
sequence_length = 23
input_tensor = tf.keras.Input(shape=(sequence_length), dtype=tf.int32)
output_tensor = test_layer(input_tensor)
# Create a model from the test layer.
model = tf.keras.Model(input_tensor, output_tensor)
# Invoke the model on test data. We can't validate the output data itself
# (the NN is too complex) but this will rule out structural runtime errors.
batch_size = 3
input_data = np.random.randint(
vocab_size, size=(batch_size, sequence_length))
output = model.predict(input_data)
self.assertEqual(tf.float32, output.dtype)
if __name__ == "__main__":
tf.test.main()
......@@ -160,7 +160,6 @@ class RelativePositionEmbedding(tf.keras.layers.Layer):
"hidden_size": self._hidden_size,
"min_timescale": self._min_timescale,
"max_timescale": self._max_timescale,
"length": self._length,
}
base_config = super(RelativePositionEmbedding, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
......
......@@ -23,7 +23,6 @@ import gin
import tensorflow as tf
from official.nlp.modeling.layers import attention
from official.nlp.modeling.layers import dense_einsum
@tf.keras.utils.register_keras_serializable(package="Text")
......@@ -109,19 +108,20 @@ class ReZeroTransformer(tf.keras.layers.Layer):
"The input size (%d) is not a multiple of the number of attention "
"heads (%d)" % (hidden_size, self._num_heads))
self._attention_head_size = int(hidden_size // self._num_heads)
self._attention_layer = attention.MultiHeadAttention(
num_heads=self._num_heads,
key_size=self._attention_head_size,
dropout=self._attention_dropout_rate,
common_kwargs = dict(
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="self_attention")
bias_constraint=self._bias_constraint)
self._attention_layer = attention.MultiHeadAttention(
num_heads=self._num_heads,
key_size=self._attention_head_size,
dropout=self._attention_dropout_rate,
name="self_attention",
**common_kwargs)
self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
if self._use_layer_norm:
# Use float32 in layernorm for numeric stability.
......@@ -132,17 +132,12 @@ class ReZeroTransformer(tf.keras.layers.Layer):
axis=-1,
epsilon=1e-12,
dtype=tf.float32))
self._intermediate_dense = dense_einsum.DenseEinsum(
output_shape=self._intermediate_size,
activation=None,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="intermediate")
self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
"abc,cd->abd",
output_shape=(None, self._intermediate_size),
bias_axes="d",
name="intermediate",
**common_kwargs)
policy = tf.keras.mixed_precision.experimental.global_policy()
if policy.name == "mixed_bfloat16":
# bfloat16 causes BERT with the LAMB optimizer to not converge
......@@ -151,16 +146,12 @@ class ReZeroTransformer(tf.keras.layers.Layer):
policy = tf.float32
self._intermediate_activation_layer = tf.keras.layers.Activation(
self._intermediate_activation, dtype=policy)
self._output_dense = dense_einsum.DenseEinsum(
output_shape=hidden_size,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="output")
self._output_dense = tf.keras.layers.experimental.EinsumDense(
"abc,cd->abd",
output_shape=(None, hidden_size),
bias_axes="d",
name="output",
**common_kwargs)
self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
if self._use_layer_norm:
# Use float32 in layernorm for numeric stability.
......@@ -222,9 +213,9 @@ class ReZeroTransformer(tf.keras.layers.Layer):
attention_mask = attention_mask[:, 0:self._output_range, :]
else:
target_tensor = input_tensor
attention_inputs = [target_tensor, input_tensor]
attention_output = self._attention_layer(attention_inputs, attention_mask)
attention_output = self._attention_layer(
query=target_tensor, value=input_tensor, attention_mask=attention_mask)
attention_output = self._attention_dropout(attention_output)
attention_output = target_tensor + self._rezero_a * attention_output
if self._use_layer_norm:
......
......@@ -58,7 +58,7 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
bias_constraint: Constraint for dense layer kernels.
"""
def _build_attention(self, qkv_rank):
def build_attention(self, qkv_rank):
"""Builds multi-head dot-product attention computations.
This function overrides base class to create additional linear projection
......@@ -67,7 +67,7 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
Args:
qkv_rank: the rank of query, key, value tensors after projection.
"""
super(TalkingHeadsAttention, self)._build_attention(qkv_rank)
super(TalkingHeadsAttention, self).build_attention(qkv_rank)
# Build an equation:
# (<batch_dims>, num_heads_a, ...),(num_heads_a, num_heads_b) ->
......@@ -103,11 +103,11 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
dtype=self.dtype,
trainable=True)
def _compute_attention(self,
query_tensor,
key_tensor,
value_tensor,
attention_mask=None):
def compute_attention(self,
query_tensor,
key_tensor,
value_tensor,
attention_mask=None):
"""Applies Dot-product attention with query, key, value tensors.
This function overrides base class to apply additional linear projection
......
......@@ -46,7 +46,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
value = tf.keras.Input(shape=(20, 80))
output = test_layer([query, value])
output = test_layer(query=query, value=value)
self.assertEqual(output.shape.as_list(), [None] + output_dims)
def test_non_masked_self_attention(self):
......@@ -55,7 +55,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
num_heads=12, key_size=64)
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
output = test_layer([query, query])
output = test_layer(query=query, value=query)
self.assertEqual(output.shape.as_list(), [None, 40, 80])
def test_attention_scores(self):
......@@ -64,7 +64,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
num_heads=12, key_size=64, return_attention_scores=True)
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
output, coef = test_layer([query, query])
output, coef = test_layer(query=query, value=query)
self.assertEqual(output.shape.as_list(), [None, 40, 80])
self.assertEqual(coef.shape.as_list(), [None, 12, 40, 40])
......@@ -78,7 +78,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
query = tf.keras.Input(shape=(4, 8))
value = tf.keras.Input(shape=(2, 8))
mask_tensor = tf.keras.Input(shape=(4, 2))
output = test_layer([query, value], mask_tensor)
output = test_layer(query=query, value=value, attention_mask=mask_tensor)
# Create a model containing the test layer.
model = tf.keras.Model([query, value, mask_tensor], output)
......@@ -102,7 +102,8 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
# Tests the layer with three inputs: Q, K, V.
key = tf.keras.Input(shape=(2, 8))
output = test_layer([query, value, key], mask_tensor)
output = test_layer(
query=query, value=value, key=key, attention_mask=mask_tensor)
model = tf.keras.Model([query, value, key, mask_tensor], output)
masked_output_data = model.predict([from_data, to_data, to_data, mask_data])
......@@ -127,7 +128,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
output = test_layer([query, query])
output = test_layer(query=query, value=query)
self.assertEqual(output.shape.as_list(), [None, 40, 80])
@parameterized.named_parameters(
......@@ -149,11 +150,12 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
# Invoke the data with a random set of mask data. This should mask at least
# one element.
mask_data = np.random.randint(2, size=mask_shape).astype("bool")
output = test_layer([query, value], mask_data)
output = test_layer(query=query, value=value, attention_mask=mask_data)
# Invoke the same data, but with a null mask (where no elements are masked).
null_mask_data = np.ones(mask_shape)
unmasked_output = test_layer([query, value], null_mask_data)
unmasked_output = test_layer(
query=query, value=value, attention_mask=null_mask_data)
# Because one data is masked and one is not, the outputs should not be the
# same.
self.assertNotAllClose(output, unmasked_output)
......
......@@ -23,7 +23,7 @@ import gin
import tensorflow as tf
from official.nlp.modeling.layers import attention
from official.nlp.modeling.layers import dense_einsum
from official.nlp.modeling.layers import multi_channel_attention
from official.nlp.modeling.layers.util import tf_function_if_eager
......@@ -49,6 +49,12 @@ class Transformer(tf.keras.layers.Layer):
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer kernels.
use_bias: Whether to enable use_bias in attention layer. If set False,
use_bias in attention layer is disabled.
norm_first: Whether to normalize inputs to attention and intermediate dense
layers. If set False, output of attention and intermediate dense layers is
normalized.
norm_epsilon: Epsilon value to initialize normalization layers.
"""
def __init__(self,
......@@ -65,6 +71,9 @@ class Transformer(tf.keras.layers.Layer):
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
use_bias=True,
norm_first=False,
norm_epsilon=1e-12,
**kwargs):
super(Transformer, self).__init__(**kwargs)
......@@ -78,8 +87,12 @@ class Transformer(tf.keras.layers.Layer):
self._bias_initializer = tf.keras.initializers.get(bias_initializer)
self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
self._bias_constraint = tf.keras.constraints.get(bias_constraint)
self._use_bias = use_bias
self._norm_first = norm_first
self._norm_epsilon = norm_epsilon
def build(self, input_shape):
input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
......@@ -104,23 +117,21 @@ class Transformer(tf.keras.layers.Layer):
"The input size (%d) is not a multiple of the number of attention "
"heads (%d)" % (hidden_size, self._num_heads))
self._attention_head_size = int(hidden_size // self._num_heads)
self._attention_layer = attention.MultiHeadAttention(
num_heads=self._num_heads,
key_size=self._attention_head_size,
dropout=self._attention_dropout_rate,
common_kwargs = dict(
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="self_attention")
# pylint: disable=protected-access
self._attention_layer.build([input_tensor_shape] * 3)
self._attention_output_dense = self._attention_layer._output_dense
# pylint: enable=protected-access
bias_constraint=self._bias_constraint)
self._attention_layer = attention.MultiHeadAttention(
num_heads=self._num_heads,
key_size=self._attention_head_size,
dropout=self._attention_dropout_rate,
use_bias=self._use_bias,
name="self_attention",
**common_kwargs)
self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
# Use float32 in layernorm for numeric stability.
# It is probably safe in mixed_float16, but we haven't validated this yet.
......@@ -128,19 +139,14 @@ class Transformer(tf.keras.layers.Layer):
tf.keras.layers.LayerNormalization(
name="self_attention_layer_norm",
axis=-1,
epsilon=1e-12,
epsilon=self._norm_epsilon,
dtype=tf.float32))
self._intermediate_dense = dense_einsum.DenseEinsum(
output_shape=self._intermediate_size,
activation=None,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="intermediate")
self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
"abc,cd->abd",
output_shape=(None, self._intermediate_size),
bias_axes="d",
name="intermediate",
**common_kwargs)
policy = tf.keras.mixed_precision.experimental.global_policy()
if policy.name == "mixed_bfloat16":
# bfloat16 causes BERT with the LAMB optimizer to not converge
......@@ -149,20 +155,19 @@ class Transformer(tf.keras.layers.Layer):
policy = tf.float32
self._intermediate_activation_layer = tf.keras.layers.Activation(
self._intermediate_activation, dtype=policy)
self._output_dense = dense_einsum.DenseEinsum(
output_shape=hidden_size,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="output")
self._output_dense = tf.keras.layers.experimental.EinsumDense(
"abc,cd->abd",
output_shape=(None, hidden_size),
bias_axes="d",
name="output",
**common_kwargs)
self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
# Use float32 in layernorm for numeric stability.
self._output_layer_norm = tf.keras.layers.LayerNormalization(
name="output_layer_norm", axis=-1, epsilon=1e-12, dtype=tf.float32)
name="output_layer_norm",
axis=-1,
epsilon=self._norm_epsilon,
dtype=tf.float32)
super(Transformer, self).build(input_shape)
......@@ -193,7 +198,13 @@ class Transformer(tf.keras.layers.Layer):
"kernel_constraint":
tf.keras.constraints.serialize(self._kernel_constraint),
"bias_constraint":
tf.keras.constraints.serialize(self._bias_constraint)
tf.keras.constraints.serialize(self._bias_constraint),
"use_bias":
self._use_bias,
"norm_first":
self._norm_first,
"norm_epsilon":
self._norm_epsilon
}
base_config = super(Transformer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
......@@ -208,13 +219,22 @@ class Transformer(tf.keras.layers.Layer):
target_tensor = input_tensor[:, 0:self._output_range, :]
attention_mask = attention_mask[:, 0:self._output_range, :]
else:
if self._norm_first:
source_tensor = input_tensor
input_tensor = self._attention_layer_norm(input_tensor)
target_tensor = input_tensor
attention_inputs = [target_tensor, input_tensor]
attention_output = self._attention_layer(attention_inputs, attention_mask)
attention_output = self._attention_layer(
query=target_tensor, value=input_tensor, attention_mask=attention_mask)
attention_output = self._attention_dropout(attention_output)
attention_output = self._attention_layer_norm(target_tensor +
attention_output)
if self._norm_first:
attention_output = source_tensor + attention_output
else:
attention_output = self._attention_layer_norm(target_tensor +
attention_output)
if self._norm_first:
source_attention_output = attention_output
attention_output = self._output_layer_norm(attention_output)
intermediate_output = self._intermediate_dense(attention_output)
intermediate_output = self._intermediate_activation_layer(
intermediate_output)
......@@ -224,7 +244,10 @@ class Transformer(tf.keras.layers.Layer):
# is always fp32 for now. Cast layer_output to fp32 for the subsequent
# add.
layer_output = tf.cast(layer_output, tf.float32)
layer_output = self._output_layer_norm(layer_output + attention_output)
if self._norm_first:
layer_output = source_attention_output + layer_output
else:
layer_output = self._output_layer_norm(layer_output + attention_output)
return layer_output
......@@ -236,3 +259,259 @@ class CompiledTransformer(Transformer):
@tf_function_if_eager(experimental_compile=True)
def call(self, inputs):
return super(CompiledTransformer, self).call(inputs)
@tf.keras.utils.register_keras_serializable(package="Text")
class TransformerDecoderLayer(tf.keras.layers.Layer):
"""Single transformer layer for decoder.
It has three sub-layers:
(1) a multi-head self-attention mechanism.
(2) a encoder-decoder attention.
(3) a positionwise fully connected feed-forward network.
Arguments:
num_attention_heads: Number of attention heads.
intermediate_size: Size of the intermediate layer.
intermediate_activation: Activation for the intermediate layer.
dropout_rate: Dropout probability for the post-attention and output dropout.
attention_dropout_rate: Dropout probability for within the attention layer.
multi_channel_cross_attention: Whether to use `MultiChannelAttention` for
cross-attention between target sequences and source sequences.
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels.
bias_regularizer: Regularizer for dense layer biases.
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer kernels.
use_bias: Whether to enable use_bias in attention layer. If set False,
use_bias in attention layer is disabled.
norm_first: Whether to normalize inputs to attention and intermediate dense
layers. If set False, output of attention and intermediate dense layers is
normalized.
norm_epsilon: Epsilon value to initialize normalization layers.
"""
def __init__(self,
num_attention_heads,
intermediate_size,
intermediate_activation,
dropout_rate=0.0,
attention_dropout_rate=0.0,
multi_channel_cross_attention=False,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
use_bias=True,
norm_first=False,
norm_epsilon=1e-12,
**kwargs):
super(TransformerDecoderLayer, self).__init__(**kwargs)
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.intermediate_activation = tf.keras.activations.get(
intermediate_activation)
self.dropout_rate = dropout_rate
self.attention_dropout_rate = attention_dropout_rate
self.multi_channel_cross_attention = multi_channel_cross_attention
self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
self._bias_initializer = tf.keras.initializers.get(bias_initializer)
self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
self._bias_constraint = tf.keras.constraints.get(bias_constraint)
self._use_bias = use_bias
self._norm_first = norm_first
self._norm_epsilon = norm_epsilon
if self.multi_channel_cross_attention:
self._cross_attention_cls = multi_channel_attention.MultiChannelAttention
else:
self._cross_attention_cls = attention.MultiHeadAttention
def build(self, input_shape):
target_tensor_shape = tf.TensorShape(input_shape[0])
if len(target_tensor_shape) != 3:
raise ValueError("TransformerLayer expects a three-dimensional input of "
"shape [batch, sequence, width].")
hidden_size = target_tensor_shape[2]
if hidden_size % self.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (hidden_size, self.num_attention_heads))
self.attention_head_size = int(hidden_size / self.num_attention_heads)
common_kwargs = dict(
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint)
# Self attention.
self.self_attention = attention.CachedAttention(
num_heads=self.num_attention_heads,
key_size=self.attention_head_size,
dropout=self.attention_dropout_rate,
use_bias=self._use_bias,
name="self_attention",
**common_kwargs)
self.self_attention_output_dense = tf.keras.layers.experimental.EinsumDense(
"abc,cd->abd",
output_shape=(None, hidden_size),
bias_axes="d",
name="output",
**common_kwargs)
self.self_attention_dropout = tf.keras.layers.Dropout(
rate=self.dropout_rate)
self.self_attention_layer_norm = (
tf.keras.layers.LayerNormalization(
name="self_attention_layer_norm",
axis=-1,
epsilon=self._norm_epsilon))
# Encoder-decoder attention.
self.encdec_attention = self._cross_attention_cls(
num_heads=self.num_attention_heads,
key_size=self.attention_head_size,
dropout=self.attention_dropout_rate,
output_shape=hidden_size,
use_bias=self._use_bias,
name="attention/encdec",
**common_kwargs)
self.encdec_attention_dropout = tf.keras.layers.Dropout(
rate=self.dropout_rate)
self.encdec_attention_layer_norm = (
tf.keras.layers.LayerNormalization(
name="attention/encdec_output_layer_norm",
axis=-1,
epsilon=self._norm_epsilon))
# Feed-forward projection.
self.intermediate_dense = tf.keras.layers.experimental.EinsumDense(
"abc,cd->abd",
output_shape=(None, self.intermediate_size),
bias_axes="d",
name="intermediate",
**common_kwargs)
self.intermediate_activation_layer = tf.keras.layers.Activation(
self.intermediate_activation)
self.output_dense = tf.keras.layers.experimental.EinsumDense(
"abc,cd->abd",
output_shape=(None, hidden_size),
bias_axes="d",
name="output",
**common_kwargs)
self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
self.output_layer_norm = tf.keras.layers.LayerNormalization(
name="output_layer_norm", axis=-1, epsilon=self._norm_epsilon)
super(TransformerDecoderLayer, self).build(input_shape)
def get_config(self):
config = {
"num_attention_heads":
self.num_attention_heads,
"intermediate_size":
self.intermediate_size,
"intermediate_activation":
self.intermediate_activation,
"dropout_rate":
self.dropout_rate,
"attention_dropout_rate":
self.attention_dropout_rate,
"multi_channel_cross_attention":
self.multi_channel_cross_attention,
"kernel_initializer":
tf.keras.initializers.serialize(self._kernel_initializer),
"bias_initializer":
tf.keras.initializers.serialize(self._bias_initializer),
"kernel_regularizer":
tf.keras.regularizers.serialize(self._kernel_regularizer),
"bias_regularizer":
tf.keras.regularizers.serialize(self._bias_regularizer),
"activity_regularizer":
tf.keras.regularizers.serialize(self._activity_regularizer),
"kernel_constraint":
tf.keras.constraints.serialize(self._kernel_constraint),
"bias_constraint":
tf.keras.constraints.serialize(self._bias_constraint),
"use_bias":
self._use_bias,
"norm_first":
self._norm_first,
"norm_epsilon":
self._norm_epsilon
}
base_config = super(TransformerDecoderLayer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def common_layers_with_encoder(self):
"""Gets layer objects that can make a Transformer encoder block."""
return [
self.self_attention, self.self_attention_layer_norm,
self.intermediate_dense, self.output_dense, self.output_layer_norm
]
def call(self, inputs, cache=None, decode_loop_step=None):
if self.multi_channel_cross_attention:
if len(inputs) != 5:
raise ValueError(
"TransformerDecoderLayer must have 5 inputs, when it uses "
"multi_channel_cross_attention. But it got: %d" % len(inputs))
elif len(inputs) != 4:
raise ValueError(
"TransformerDecoderLayer must have 4 inputs, but it got: %d" %
len(inputs))
input_tensor, memory, attention_mask, self_attention_mask = inputs[:4]
source_tensor = input_tensor
if self._norm_first:
input_tensor = self.self_attention_layer_norm(input_tensor)
self_attention_output, cache = self.self_attention(
query=input_tensor,
value=input_tensor,
attention_mask=self_attention_mask,
cache=cache,
decode_loop_step=decode_loop_step)
self_attention_output = self.self_attention_dropout(self_attention_output)
if self._norm_first:
self_attention_output = source_tensor + self_attention_output
else:
self_attention_output = self.self_attention_layer_norm(
input_tensor + self_attention_output)
if self._norm_first:
source_self_attention_output = self_attention_output
self_attention_output = self.encdec_attention_layer_norm(
self_attention_output)
cross_attn_inputs = dict(
query=self_attention_output,
value=memory,
attention_mask=attention_mask)
if self.multi_channel_cross_attention:
# Accesses the 5-th input tensor for the doc-attention probabilities.
cross_attn_inputs["context_attention_weights"] = inputs[-1]
attention_output = self.encdec_attention(**cross_attn_inputs)
attention_output = self.encdec_attention_dropout(attention_output)
if self._norm_first:
attention_output = source_self_attention_output + attention_output
else:
attention_output = self.encdec_attention_layer_norm(
self_attention_output + attention_output)
if self._norm_first:
source_attention_output = attention_output
attention_output = self.output_layer_norm(attention_output)
intermediate_output = self.intermediate_dense(attention_output)
intermediate_output = self.intermediate_activation_layer(
intermediate_output)
layer_output = self.output_dense(intermediate_output)
layer_output = self.output_dropout(layer_output)
if self._norm_first:
layer_output = source_attention_output + layer_output
else:
layer_output = self.output_layer_norm(layer_output + attention_output)
return layer_output, cache
......@@ -262,9 +262,8 @@ class TransformerScaffold(tf.keras.layers.Layer):
else:
input_tensor, attention_mask = (inputs, None)
attention_inputs = [input_tensor, input_tensor]
attention_output = self._attention_layer(attention_inputs, attention_mask)
attention_output = self._attention_layer(
query=input_tensor, value=input_tensor, attention_mask=attention_mask)
attention_output = self._attention_dropout(attention_output)
attention_output = self._attention_layer_norm(input_tensor +
attention_output)
......
......@@ -39,10 +39,10 @@ class ValidatedAttentionLayer(attention.MultiHeadAttention):
super(ValidatedAttentionLayer, self).__init__(**kwargs)
self.list = call_list
def call(self, inputs, attention_mask=None):
def call(self, query, value, attention_mask=None):
self.list.append(True)
return super(ValidatedAttentionLayer, self).call(
inputs, attention_mask=attention_mask)
query, value, attention_mask=attention_mask)
def get_config(self):
config = super(ValidatedAttentionLayer, self).get_config()
......
......@@ -152,7 +152,8 @@ class TransformerLayerTest(keras_parameterized.TestCase):
_ = new_layer([input_data, mask_data])
new_layer.set_weights(test_layer.get_weights())
new_output_tensor = new_layer([input_data, mask_data])
self.assertAllClose(new_output_tensor, output_tensor[:, 0:1, :])
self.assertAllClose(
new_output_tensor, output_tensor[:, 0:1, :], atol=5e-5, rtol=0.003)
def test_layer_invocation_with_float16_dtype(self, transformer_cls):
tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
......@@ -215,5 +216,113 @@ class TransformerLayerTest(keras_parameterized.TestCase):
self.assertAllEqual([1, input_length, width], output_data.shape)
@keras_parameterized.run_all_keras_modes
class TransformerArgumentTest(keras_parameterized.TestCase):
def test_use_bias_norm_first(self):
num_attention_heads = 2
hidden_size = 16
encoder_block = transformer.Transformer(
num_attention_heads=num_attention_heads,
intermediate_size=32,
intermediate_activation='relu',
dropout_rate=0.1,
attention_dropout_rate=0.1,
use_bias=False,
norm_first=True,
norm_epsilon=1e-6)
# Forward path.
dummy_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
dummy_mask = tf.zeros([2, 4, 4], dtype=tf.float32)
inputs = [dummy_tensor, dummy_mask]
output = encoder_block(inputs)
self.assertEqual(output.shape, (2, 4, hidden_size))
def test_get_config(self):
num_attention_heads = 2
encoder_block = transformer.Transformer(
num_attention_heads=num_attention_heads,
intermediate_size=32,
intermediate_activation='relu',
dropout_rate=0.1,
attention_dropout_rate=0.1,
use_bias=False,
norm_first=True,
norm_epsilon=1e-6)
encoder_block_config = encoder_block.get_config()
new_encoder_block = transformer.Transformer.from_config(
encoder_block_config)
self.assertEqual(encoder_block_config, new_encoder_block.get_config())
def _create_cache(batch_size, init_decode_length, num_heads, head_size):
return {
'key':
tf.zeros([batch_size, init_decode_length, num_heads, head_size],
dtype=tf.float32),
'value':
tf.zeros([batch_size, init_decode_length, num_heads, head_size],
dtype=tf.float32)
}
@keras_parameterized.run_all_keras_modes
class TransformerDecoderLayerTest(keras_parameterized.TestCase):
def test_decoder_block_with_cache(self):
num_attention_heads = 2
hidden_size = 16
decoder_block = transformer.TransformerDecoderLayer(
num_attention_heads=num_attention_heads,
intermediate_size=32,
intermediate_activation='relu',
dropout_rate=0.1,
attention_dropout_rate=0.1)
# Forward path.
dummy_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
dummy_mask = tf.zeros([2, 4, 4], dtype=tf.float32)
inputs = [dummy_tensor, dummy_tensor, dummy_mask, dummy_mask]
cache = _create_cache(2, 0, num_attention_heads,
hidden_size // num_attention_heads)
output, cache = decoder_block(inputs, cache)
self.assertEqual(output.shape, (2, 4, hidden_size))
self.assertEqual(cache['value'].shape, (2, 4, 2, 8))
def test_use_bias_norm_first(self):
num_attention_heads = 2
hidden_size = 16
decoder_block = transformer.TransformerDecoderLayer(
num_attention_heads=num_attention_heads,
intermediate_size=32,
intermediate_activation='relu',
dropout_rate=0.1,
attention_dropout_rate=0.1,
use_bias=False,
norm_first=True,
norm_epsilon=1e-6)
# Forward path.
dummy_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
dummy_mask = tf.zeros([2, 4, 4], dtype=tf.float32)
inputs = [dummy_tensor, dummy_tensor, dummy_mask, dummy_mask]
output, _ = decoder_block(inputs)
self.assertEqual(output.shape, (2, 4, hidden_size))
def test_get_config(self):
num_attention_heads = 2
decoder_block = transformer.TransformerDecoderLayer(
num_attention_heads=num_attention_heads,
intermediate_size=32,
intermediate_activation='relu',
dropout_rate=0.1,
attention_dropout_rate=0.1,
use_bias=False,
norm_first=True,
norm_epsilon=1e-6)
decoder_block_config = decoder_block.get_config()
new_decoder_block = transformer.TransformerDecoderLayer.from_config(
decoder_block_config)
self.assertEqual(decoder_block_config, new_decoder_block.get_config())
if __name__ == '__main__':
tf.test.main()
......@@ -4,6 +4,3 @@ Losses contains common loss computation used in NLP tasks.
* `weighted_sparse_categorical_crossentropy_loss` computes per-batch sparse
categorical crossentropy loss.
* `weighted_sparse_categorical_crossentropy_per_example_loss` computes
per-example sparse categorical crossentropy loss.
......@@ -14,4 +14,3 @@
# ==============================================================================
"""Activations package definition. Subject to change."""
from official.nlp.modeling.losses.weighted_sparse_categorical_crossentropy import loss as weighted_sparse_categorical_crossentropy_loss
from official.nlp.modeling.losses.weighted_sparse_categorical_crossentropy import per_example_loss as weighted_sparse_categorical_crossentropy_per_example_loss
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Sparse categorical cross-entropy losses."""
"""Weighted sparse categorical cross-entropy losses."""
from __future__ import absolute_import
from __future__ import division
......@@ -43,37 +43,7 @@ def _validate_rank(labels, predictions, weights):
"predictions.shape was %s.") % (labels.shape, predictions.shape))
def per_example_loss(labels, predictions, weights=None):
"""Calculate a per-example sparse categorical crossentropy loss.
This loss function assumes that the predictions are post-softmax.
Args:
labels: The labels to evaluate against. Should be a set of integer indices
ranging from 0 to (vocab_size-1).
predictions: The network predictions. Should have softmax already applied.
weights: An optional weight array of the same shape as the 'labels' array.
If None, all examples will be used.
Returns:
A tensor of shape predictions.shape[:-1] containing the per-example
loss.
"""
# When using these functions with the Keras core API, we will need to squeeze
# the labels tensor - Keras adds a spurious inner dimension.
labels, predictions = _adjust_labels(labels, predictions)
_validate_rank(labels, predictions, weights)
labels_one_hot = tf.one_hot(labels, predictions.shape[-1])
labels_one_hot = tf.cast(labels_one_hot, predictions.dtype)
per_example_loss_data = -tf.reduce_sum(
predictions * labels_one_hot, axis=[-1])
if weights is not None:
weights = tf.cast(weights, per_example_loss_data.dtype)
per_example_loss_data = weights * per_example_loss_data
return per_example_loss_data
def loss(labels, predictions, weights=None):
def loss(labels, predictions, weights=None, from_logits=False):
"""Calculate a per-batch sparse categorical crossentropy loss.
This loss function assumes that the predictions are post-softmax.
......@@ -83,6 +53,7 @@ def loss(labels, predictions, weights=None):
predictions: The network predictions. Should have softmax already applied.
weights: An optional weight array of the same shape as the 'labels' array.
If None, all examples will be used.
from_logits: Whether the input predictions are logits.
Returns:
A loss scalar.
......@@ -95,12 +66,11 @@ def loss(labels, predictions, weights=None):
labels, predictions = _adjust_labels(labels, predictions)
_validate_rank(labels, predictions, weights)
per_example_loss_data = per_example_loss(labels, predictions, weights)
example_losses = tf.keras.losses.sparse_categorical_crossentropy(
labels, predictions, from_logits=from_logits)
if weights is None:
return tf.reduce_mean(per_example_loss_data)
else:
numerator = tf.reduce_sum(per_example_loss_data)
weights = tf.cast(weights, predictions.dtype)
denominator = tf.reduce_sum(weights) + 1e-5
return numerator / denominator
return tf.reduce_mean(example_losses)
weights = tf.cast(weights, predictions.dtype)
return tf.math.divide_no_nan(
tf.reduce_sum(example_losses * weights), tf.reduce_sum(weights))
......@@ -53,8 +53,7 @@ class ClassificationLossTest(keras_parameterized.TestCase):
# Create a maskedLM from the transformer stack.
test_layer = layers.MaskedLM(
embedding_table=xformer_stack.get_embedding_table(),
output=output)
embedding_table=xformer_stack.get_embedding_table(), output=output)
# Create a model from the masked LM layer.
lm_input_tensor = tf.keras.Input(shape=(sequence_length, hidden_size))
......@@ -63,123 +62,6 @@ class ClassificationLossTest(keras_parameterized.TestCase):
output = test_layer(lm_input_tensor, masked_positions=masked_lm_positions)
return tf.keras.Model([lm_input_tensor, masked_lm_positions], output)
def create_classification_model(self, input_width, num_classes):
test_object = networks.Classification(
input_width=input_width, num_classes=num_classes)
# Create a 2-dimensional input (the first dimension is implicit).
pooled_data = tf.keras.Input(shape=(input_width,), dtype=tf.float32)
output = test_object(pooled_data)
return tf.keras.Model(pooled_data, output)
def test_per_example_loss_3d_input(self):
"""Test per-example loss with a 3-dimensional input, from a masked LM."""
vocab_size = 100
sequence_length = 32
hidden_size = 64
num_predictions = 21
model = self.create_lm_model(
vocab_size=vocab_size,
sequence_length=sequence_length,
hidden_size=hidden_size,
num_predictions=num_predictions)
# Get the output of the masked LM.
batch_size = 3
lm_input_data = 10 * np.random.random_sample(
(batch_size, sequence_length, hidden_size))
masked_position_data = np.random.randint(
2, size=(batch_size, num_predictions))
output_data = model.predict([lm_input_data, masked_position_data])
# Calculate per-example loss.
labels = np.random.randint(vocab_size, size=(batch_size, num_predictions))
per_example_loss_data = weighted_sparse_categorical_crossentropy.per_example_loss(
predictions=output_data, labels=labels)
# Per-example loss data should have one value per prediction, and those
# values shouldn't be zero in this case (as we're using random data).
expected_shape = [batch_size, num_predictions]
self.assertEqual(expected_shape, per_example_loss_data.shape.as_list())
self.assertNotAllClose(
tf.zeros_like(per_example_loss_data), per_example_loss_data)
def test_per_example_loss_2d_input(self):
"""Test per-example loss with a 2-d input, from a classifier."""
input_width = 512
num_classes = 10
model = self.create_classification_model(input_width, num_classes)
# Invoke the network as part of a Model.
batch_size = 3
input_data = 10 * np.random.random_sample((batch_size, input_width))
output_data = model.predict(input_data)
# Calculate per example loss.
labels = np.random.randint(num_classes, size=(batch_size))
per_example_loss_data = weighted_sparse_categorical_crossentropy.per_example_loss(
predictions=output_data, labels=labels)
# Per-example loss data should have one value per batch item, and those
# values shouldn't be zero in this case (as we're using random data).
self.assertEqual([batch_size], per_example_loss_data.shape.as_list())
self.assertNotAllClose(
tf.zeros_like(per_example_loss_data), per_example_loss_data)
def test_per_example_loss_weights_3d_input(self):
"""Test weighted per-example loss with a 3-d input, from a masked LM."""
vocab_size = 100
sequence_length = 32
hidden_size = 64
num_predictions = 21
model = self.create_lm_model(
vocab_size=vocab_size,
sequence_length=sequence_length,
hidden_size=hidden_size,
num_predictions=num_predictions)
# Get the output of the masked LM.
batch_size = 3
lm_input_data = 10 * np.random.random_sample(
(batch_size, sequence_length, hidden_size))
masked_position_data = np.random.randint(
2, size=(batch_size, num_predictions))
output_data = model.predict([lm_input_data, masked_position_data])
# Calculate per-example loss with weights.
labels = np.random.randint(vocab_size, size=(batch_size, num_predictions))
weights = np.random.randint(2, size=(batch_size, num_predictions))
per_example_loss_data = weighted_sparse_categorical_crossentropy.per_example_loss(
predictions=output_data, labels=labels, weights=weights)
# Weighted per-example loss data should be equivalent to multiplying the
# loss tensor by the weights tensor.
expected_weighted_loss = per_example_loss_data * weights
self.assertAllClose(expected_weighted_loss, per_example_loss_data)
def test_per_example_loss_weights_2d_input(self):
"""Test weighted per-example loss with a 2-d input, from a classifier."""
input_width = 512
num_classes = 10
model = self.create_classification_model(input_width, num_classes)
# Invoke the network as part of a Model.
batch_size = 3
input_data = 10 * np.random.random_sample((batch_size, input_width))
output_data = model.predict(input_data)
# Calculate per-example loss with weights.
labels = np.random.randint(num_classes, size=(batch_size))
weights = np.random.randint(2, size=(batch_size))
per_example_loss_data = weighted_sparse_categorical_crossentropy.per_example_loss(
predictions=output_data, labels=labels, weights=weights)
# Weighted per-example loss data should be equivalent to multiplying the
# loss tensor by the weights tensor.
expected_weighted_loss = per_example_loss_data * weights
self.assertAllClose(expected_weighted_loss, per_example_loss_data)
def test_loss_3d_input(self):
"""Test overall loss with a 3-dimensional input, from a masked LM."""
vocab_size = 100
......@@ -213,26 +95,6 @@ class ClassificationLossTest(keras_parameterized.TestCase):
self.assertNotAllClose(
tf.zeros_like(per_example_loss_data), per_example_loss_data)
def test_loss_2d_input(self):
"""Test overall loss with a 2-d input, from a classifier."""
input_width = 512
num_classes = 10
model = self.create_classification_model(input_width, num_classes)
# Invoke the network as part of a Model.
batch_size = 3
input_data = 10 * np.random.random_sample((batch_size, input_width))
output_data = model.predict(input_data)
# Calculate per example loss.
labels = np.random.randint(num_classes, size=(batch_size))
loss_data = weighted_sparse_categorical_crossentropy.loss(
predictions=output_data, labels=labels)
# Loss data should have one value only, and that value shouldn't be zero in
# this case (as we're using random data).
self.assertNotAllClose(0, loss_data)
def test_loss_weights_3d_input(self):
"""Test masked loss with a 3-dimensional input, from a masked LM."""
vocab_size = 100
......@@ -262,26 +124,6 @@ class ClassificationLossTest(keras_parameterized.TestCase):
# Because the tensor is fully masked, the loss should be 0.
self.assertAllClose(0, weighted_loss_data)
def test_loss_weights_2d_input(self):
"""Test masked loss with a 2-d input, from a classifier."""
input_width = 512
num_classes = 10
model = self.create_classification_model(input_width, num_classes)
# Invoke the network as part of a Model.
batch_size = 3
input_data = 10 * np.random.random_sample((batch_size, input_width))
output_data = model.predict(input_data)
# Calculate a fully masked weight tensor. This should give a loss of zero.
labels = np.random.randint(num_classes, size=(batch_size))
null_weights = np.zeros((batch_size))
weighted_loss_data = weighted_sparse_categorical_crossentropy.loss(
predictions=output_data, labels=labels, weights=null_weights)
# Because the tensor is fully masked, the loss should be 0.
self.assertAllClose(0, weighted_loss_data)
def test_mismatched_predictions_and_labels_ranks_squeezes(self):
"""Test that the loss asserts when rank(predictions)-1 != rank(labels)."""
batch_size = 3
......@@ -289,7 +131,7 @@ class ClassificationLossTest(keras_parameterized.TestCase):
labels = np.random.randint(10, size=(batch_size, 1))
# All that this test tests is that the squeeze is successful.
_ = weighted_sparse_categorical_crossentropy.per_example_loss(
_ = weighted_sparse_categorical_crossentropy.loss(
predictions=output_data, labels=labels)
def test_mismatched_weights_and_labels_ranks_fail(self):
......@@ -299,9 +141,6 @@ class ClassificationLossTest(keras_parameterized.TestCase):
labels = np.random.randint(10, size=(batch_size, 10))
weights = np.random.randint(2, size=(batch_size))
with self.assertRaisesRegex(RuntimeError, ".*of the same rank.*"):
_ = weighted_sparse_categorical_crossentropy.per_example_loss(
predictions=output_data, labels=labels, weights=weights)
with self.assertRaisesRegex(RuntimeError, ".*of the same rank.*"):
_ = weighted_sparse_categorical_crossentropy.loss(
predictions=output_data, labels=labels, weights=weights)
......@@ -317,8 +156,6 @@ class ClassificationLossTest(keras_parameterized.TestCase):
# We're not trying to validate numerical correctness, just ensure that
# we can in fact pass tensors to these functions without causing runtime
# errors from the shape checking code.
_ = weighted_sparse_categorical_crossentropy.per_example_loss(
predictions=output_data, labels=labels, weights=weights)
_ = weighted_sparse_categorical_crossentropy.loss(
predictions=output_data, labels=labels, weights=weights)
......@@ -338,20 +175,15 @@ class ClassificationLossTest(keras_parameterized.TestCase):
[-2.7760355, -1.8219438, -3.0924666, -1.0779881, -0.9407509]]])
labels = np.array([[4, 0], [2, 2], [2, 1]])
# Validate that per_example loss calculations are the same.
per_example_loss_data = weighted_sparse_categorical_crossentropy.per_example_loss(
predictions=output_data, labels=labels)
expected_per_example_loss_data = [[1.2923571, 2.7117882],
[2.287932, 2.287932],
[3.0924666, 1.8219438]]
self.assertAllClose(expected_per_example_loss_data, per_example_loss_data)
# Validate that overall loss calculations are the same.
weights = np.array([[1, 0], [0, 0], [0, 0]])
loss_data = weighted_sparse_categorical_crossentropy.loss(
predictions=output_data, labels=labels, weights=weights)
predictions=output_data,
labels=labels,
weights=weights,
from_logits=True)
expected_loss_data = 1.2923441
self.assertAllClose(expected_loss_data, loss_data)
self.assertAllClose(expected_loss_data, loss_data, rtol=1e-3)
def test_legacy_classification_loss_compatibility(self):
"""Test to validate computational correctness during refactors."""
......@@ -362,19 +194,15 @@ class ClassificationLossTest(keras_parameterized.TestCase):
[-1.6975292e-03, -6.4009643e+00, -1.0226612e+01]])
labels = np.array([2, 1])
# Validate that per_example loss calculations are the same.
per_example_loss_data = weighted_sparse_categorical_crossentropy.per_example_loss(
predictions=output_data, labels=labels)
expected_per_example_loss_data = [6.4434357, 6.4009643]
self.assertAllClose(expected_per_example_loss_data, per_example_loss_data)
# Validate that overall loss calculations are the same.
weights = None
loss_data = weighted_sparse_categorical_crossentropy.loss(
predictions=output_data, labels=labels, weights=weights)
predictions=output_data,
labels=labels,
weights=weights,
from_logits=True)
expected_loss_data = 6.4222
self.assertAllClose(expected_loss_data, loss_data)
self.assertAllClose(expected_loss_data, loss_data, rtol=1e-3)
if __name__ == "__main__":
tf.test.main()
......@@ -10,8 +10,8 @@ model containing a single classification head using the Classification network.
It can be used as a regression model as well.
* [`BertTokenClassifier`](bert_token_classifier.py) implements a simple token
classification model containing a single classification head using the
TokenClassification network.
classification model containing a single classification head over the sequence
output embeddings.
* [`BertSpanLabeler`](bert_span_labeler.py) implementats a simple single-span
start-end predictor (that is, a model that predicts two values: a start token
......
......@@ -17,3 +17,4 @@ from official.nlp.modeling.models.bert_classifier import BertClassifier
from official.nlp.modeling.models.bert_pretrainer import BertPretrainer
from official.nlp.modeling.models.bert_span_labeler import BertSpanLabeler
from official.nlp.modeling.models.bert_token_classifier import BertTokenClassifier
from official.nlp.modeling.models.electra_pretrainer import ElectraPretrainer
......@@ -12,15 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trainer network for BERT-style models."""
"""BERT cls-token classifier."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import tensorflow as tf
from official.nlp.modeling import layers
from official.nlp.modeling import networks
......@@ -36,6 +33,9 @@ class BertClassifier(tf.keras.Model):
instantiates a classification network based on the passed `num_classes`
argument. If `num_classes` is set to 1, a regression network is instantiated.
*Note* that the model is constructed by
[Keras Functional API](https://keras.io/guides/functional_api/).
Arguments:
network: A transformer network. This network should output a sequence output
and a classification output. Furthermore, it should expose its embedding
......@@ -43,23 +43,25 @@ class BertClassifier(tf.keras.Model):
num_classes: Number of classes to predict from the classification network.
initializer: The initializer (if any) to use in the classification networks.
Defaults to a Glorot uniform initializer.
output: The output style for this network. Can be either 'logits' or
'predictions'.
dropout_rate: The dropout probability of the cls head.
use_encoder_pooler: Whether to use the pooler layer pre-defined inside
the encoder.
"""
def __init__(self,
network,
num_classes,
initializer='glorot_uniform',
output='logits',
dropout_rate=0.1,
use_encoder_pooler=True,
**kwargs):
self._self_setattr_tracking = False
self._network = network
self._config = {
'network': network,
'num_classes': num_classes,
'initializer': initializer,
'output': output,
'use_encoder_pooler': use_encoder_pooler,
}
# We want to use the inputs of the passed network as the inputs to this
......@@ -67,22 +69,36 @@ class BertClassifier(tf.keras.Model):
# when we construct the Model object at the end of init.
inputs = network.inputs
# Because we have a copy of inputs to create this Model object, we can
# invoke the Network object with its own input tensors to start the Model.
_, cls_output = network(inputs)
cls_output = tf.keras.layers.Dropout(rate=dropout_rate)(cls_output)
if use_encoder_pooler:
# Because we have a copy of inputs to create this Model object, we can
# invoke the Network object with its own input tensors to start the Model.
_, cls_output = network(inputs)
cls_output = tf.keras.layers.Dropout(rate=dropout_rate)(cls_output)
self.classifier = networks.Classification(
input_width=cls_output.shape[-1],
num_classes=num_classes,
initializer=initializer,
output=output,
name='classification')
predictions = self.classifier(cls_output)
self.classifier = networks.Classification(
input_width=cls_output.shape[-1],
num_classes=num_classes,
initializer=initializer,
output='logits',
name='sentence_prediction')
predictions = self.classifier(cls_output)
else:
sequence_output, _ = network(inputs)
self.classifier = layers.ClassificationHead(
inner_dim=sequence_output.shape[-1],
num_classes=num_classes,
initializer=initializer,
dropout_rate=dropout_rate,
name='sentence_prediction')
predictions = self.classifier(sequence_output)
super(BertClassifier, self).__init__(
inputs=inputs, outputs=predictions, **kwargs)
@property
def checkpoint_items(self):
return dict(encoder=self._network)
def get_config(self):
return self._config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment