Commit 31ca3b97 authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

resovle merge conflicts

parents 3e9d886d 7fcd7cba
# NLP Modeling Library
This libary provides a set of Keras primitives (Layers, Networks, and Models)
This library provides a set of Keras primitives (Layers, Networks, and Models)
that can be assembled into transformer-based models. They are
flexible, validated, interoperable, and both TF1 and TF2 compatible.
......@@ -16,6 +16,11 @@ standardized configuration.
* [`losses`](losses) contains common loss computation used in NLP tasks.
Please see the colab
[nlp_modeling_library_intro.ipynb]
(https://colab.sandbox.google.com/github/tensorflow/models/blob/master/official/colab/nlp/nlp_modeling_library_intro.ipynb)
for how to build transformer-based NLP models using above primitives.
Besides the pre-defined primitives, it also provides scaffold classes to allow
easy experimentation with noval achitectures, e.g., you don’t need to fork a whole Transformer object to try a different kind of attention primitive, for instance.
......@@ -33,11 +38,9 @@ embedding subnetwork (which will replace the standard embedding logic) and/or a
custom hidden layer (which will replace the Transformer instantiation in the
encoder).
BERT and ALBERT models in this repo are implemented using this library. Code examples can be found in the corresponding model folder.
Please see the colab
[customize_encoder.ipynb]
(https://colab.sandbox.google.com/github/tensorflow/models/blob/master/official/colab/nlp/customize_encoder.ipynb)
for how to use scaffold classes to build noval achitectures.
BERT and ALBERT models in this repo are implemented using this library. Code examples can be found in the corresponding model folder.
......@@ -3,11 +3,6 @@
Layers are the fundamental building blocks for NLP models. They can be used to
assemble new layers, networks, or models.
* [DenseEinsum](dense_einsum.py) implements a feedforward network using
tf.einsum. This layer contains the einsum op, the associated weight, and the
logic required to generate the einsum expression for the given
initialization parameters.
* [MultiHeadAttention](attention.py) implements an optionally masked attention
between query, key, value tensors as described in
["Attention Is All You Need"](https://arxiv.org/abs/1706.03762). If
......
......@@ -33,7 +33,7 @@ EinsumDense = tf.keras.layers.experimental.EinsumDense
_CHR_IDX = string.ascii_lowercase
def _build_attention_equation(qkv_rank, attn_axes):
def _build_attention_equation(rank, attn_axes):
"""Builds einsum equations for the attention computation.
Query, key, value inputs after projection are expected to have the shape as:
......@@ -50,19 +50,19 @@ def _build_attention_equation(qkv_rank, attn_axes):
<query attention dims>, num_heads, channels)
Args:
qkv_rank: the rank of query, key, value tensors.
rank: the rank of query, key, value tensors.
attn_axes: a list/tuple of axes, [1, rank), that will do attention.
Returns:
Einsum equations.
"""
target_notation = _CHR_IDX[:qkv_rank]
target_notation = _CHR_IDX[:rank]
# `batch_dims` includes the head dim.
batch_dims = tuple(np.delete(range(qkv_rank), attn_axes + (qkv_rank - 1,)))
letter_offset = qkv_rank
batch_dims = tuple(np.delete(range(rank), attn_axes + (rank - 1,)))
letter_offset = rank
source_notation = ""
for i in range(qkv_rank):
if i in batch_dims or i == qkv_rank - 1:
for i in range(rank):
if i in batch_dims or i == rank - 1:
source_notation += target_notation[i]
else:
source_notation += _CHR_IDX[letter_offset]
......@@ -167,8 +167,8 @@ class MultiHeadAttention(tf.keras.layers.Layer):
sequence dims. If not specified, projects back to the key feature dim.
attention_axes: axes over which the attention is applied. `None` means
attention over all axes, but batch, heads, and features.
return_attention_scores: bool, if `True`, returns the multi-head
attention scores as an additional output argument.
return_attention_scores: bool, if `True`, returns the multi-head attention
scores as an additional output argument.
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels.
......@@ -176,6 +176,13 @@ class MultiHeadAttention(tf.keras.layers.Layer):
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer kernels.
Call args:
query: Query `Tensor` of shape `[B, T, dim]`.
value: Value `Tensor` of shape `[B, S, dim]`.
key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
`value` for both `key` and `value`, which is the most common case.
attention_mask: a boolean mask of shape `[B, T, S]`, that prevents attention
to certain positions.
"""
def __init__(self,
......@@ -214,6 +221,7 @@ class MultiHeadAttention(tf.keras.layers.Layer):
self._attention_axes = (attention_axes,)
else:
self._attention_axes = attention_axes
self._built_from_signature = False
def get_config(self):
config = {
......@@ -251,17 +259,31 @@ class MultiHeadAttention(tf.keras.layers.Layer):
base_config = super(MultiHeadAttention, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def build(self, input_shape):
inputs_len = len(input_shape)
if inputs_len > 3 or inputs_len < 2:
raise ValueError(
"Expects inputs list of length 2 or 3, namely [query, value] or "
"[query, value, key]. "
"Given length: %d" % inputs_len)
tensor_shapes = tf.nest.map_structure(tf.TensorShape, input_shape)
query_shape = tensor_shapes[0]
value_shape = tensor_shapes[1]
key_shape = tensor_shapes[2] if inputs_len == 3 else value_shape
def _build_from_signature(self, query, value, key=None):
"""Builds layers and variables.
Once the method is called, self._built_from_signature will be set to True.
Args:
query: query tensor or TensorShape.
value: value tensor or TensorShape.
key: key tensor or TensorShape.
"""
self._built_from_signature = True
if hasattr(query, "shape"):
query_shape = tf.TensorShape(query.shape)
else:
query_shape = query
if hasattr(value, "shape"):
value_shape = tf.TensorShape(value.shape)
else:
value_shape = value
if key is None:
key_shape = value_shape
elif hasattr(key, "shape"):
key_shape = tf.TensorShape(key.shape)
else:
key_shape = key
common_kwargs = dict(
kernel_initializer=self._kernel_initializer,
......@@ -271,7 +293,7 @@ class MultiHeadAttention(tf.keras.layers.Layer):
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint)
with tf.init_scope():
free_dims = query_shape.rank - 1
einsum_equation, bias_axes, output_rank = _build_proj_equation(
free_dims, bound_dims=1, output_dims=2)
......@@ -302,9 +324,9 @@ class MultiHeadAttention(tf.keras.layers.Layer):
**common_kwargs)
# Builds the attention computations for multi-head dot product attention.
# These computations could be wrapped into the keras attention layer once it
# support mult-head einsum computations.
self._build_attention(output_rank)
# These computations could be wrapped into the keras attention layer once
# it support mult-head einsum computations.
self.build_attention(output_rank)
if self._output_shape:
if not isinstance(self._output_shape, collections.abc.Sized):
output_shape = [self._output_shape]
......@@ -320,35 +342,30 @@ class MultiHeadAttention(tf.keras.layers.Layer):
bias_axes=bias_axes if self._use_bias else None,
name="attention_output",
**common_kwargs)
super(MultiHeadAttention, self).build(input_shape)
def _build_attention(self, qkv_rank):
def build_attention(self, rank):
"""Builds multi-head dot-product attention computations.
This function builds attributes necessary for `_compute_attention` to
This function builds attributes necessary for `compute_attention` to
costomize attention computation to replace the default dot-product
attention.
Args:
qkv_rank: the rank of query, key, value tensors.
rank: the rank of query, key, value tensors.
"""
if self._attention_axes is None:
self._attention_axes = tuple(range(1, qkv_rank - 2))
self._attention_axes = tuple(range(1, rank - 2))
else:
self._attention_axes = tuple(self._attention_axes)
self._dot_product_equation, self._combine_equation, attn_scores_rank = (
_build_attention_equation(qkv_rank, attn_axes=self._attention_axes))
_build_attention_equation(rank, attn_axes=self._attention_axes))
norm_axes = tuple(
range(attn_scores_rank - len(self._attention_axes), attn_scores_rank))
self._masked_softmax = masked_softmax.MaskedSoftmax(
mask_expansion_axes=[1], normalization_axes=norm_axes)
self._dropout_layer = tf.keras.layers.Dropout(rate=self._dropout)
def _compute_attention(self,
query_tensor,
key_tensor,
value_tensor,
attention_mask=None):
def compute_attention(self, query, key, value, attention_mask=None):
"""Applies Dot-product attention with query, key, value tensors.
This function defines the computation inside `call` with projected
......@@ -356,9 +373,9 @@ class MultiHeadAttention(tf.keras.layers.Layer):
attention implementation.
Args:
query_tensor: Projected query `Tensor` of shape `[B, T, N, key_size]`.
key_tensor: Projected key `Tensor` of shape `[B, T, N, key_size]`.
value_tensor: Projected value `Tensor` of shape `[B, T, N, value_size]`.
query: Projected query `Tensor` of shape `[B, T, N, key_size]`.
key: Projected key `Tensor` of shape `[B, T, N, key_size]`.
value: Projected value `Tensor` of shape `[B, T, N, value_size]`.
attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
attention to certain positions.
......@@ -366,12 +383,14 @@ class MultiHeadAttention(tf.keras.layers.Layer):
attention_output: Multi-headed outputs of attention computation.
attention_scores: Multi-headed attention weights.
"""
# Note: Applying scalar multiply at the smaller end of einsum improves
# XLA performance, but may introduce slight numeric differences in
# the Transformer attention head.
query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_size)))
# Take the dot product between "query" and "key" to get the raw
# attention scores.
attention_scores = tf.einsum(self._dot_product_equation, key_tensor,
query_tensor)
attention_scores = tf.multiply(attention_scores,
1.0 / math.sqrt(float(self._key_size)))
attention_scores = tf.einsum(self._dot_product_equation, key, query)
# Normalize the attention scores to probabilities.
# `attention_scores` = [B, N, T, S]
......@@ -383,10 +402,10 @@ class MultiHeadAttention(tf.keras.layers.Layer):
# `context_layer` = [B, T, N, H]
attention_output = tf.einsum(self._combine_equation,
attention_scores_dropout, value_tensor)
attention_scores_dropout, value)
return attention_output, attention_scores
def call(self, inputs, attention_mask=None):
def call(self, query, value, key=None, attention_mask=None):
"""Implements the forward pass.
Size glossary:
......@@ -399,11 +418,10 @@ class MultiHeadAttention(tf.keras.layers.Layer):
* Value (source) attention axes shape (S), the rank must match the target.
Args:
inputs: List of the following tensors:
* query: Query `Tensor` of shape `[B, T, dim]`.
* value: Value `Tensor` of shape `[B, S, dim]`.
* key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will
use `value` for both `key` and `value`, which is the most common case.
query: Query `Tensor` of shape `[B, T, dim]`.
value: Value `Tensor` of shape `[B, S, dim]`.
key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
`value` for both `key` and `value`, which is the most common case.
attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
attention to certain positions.
......@@ -416,29 +434,24 @@ class MultiHeadAttention(tf.keras.layers.Layer):
attention
axes.
"""
inputs_len = len(inputs)
if inputs_len > 3 or inputs_len < 2:
raise ValueError(
"Expects inputs list of length 2 or 3, namely [query, value] or "
"[query, value, key]. "
"Given length: %d" % inputs_len)
query = inputs[0]
value = inputs[1]
key = inputs[2] if inputs_len == 3 else value
if not self._built_from_signature:
self._build_from_signature(query=query, value=value, key=key)
if key is None:
key = value
# N = `num_attention_heads`
# H = `size_per_head`
# `query_tensor` = [B, T, N ,H]
query_tensor = self._query_dense(query)
# `query` = [B, T, N ,H]
query = self._query_dense(query)
# `key_tensor` = [B, S, N, H]
key_tensor = self._key_dense(key)
# `key` = [B, S, N, H]
key = self._key_dense(key)
# `value_tensor` = [B, S, N, H]
value_tensor = self._value_dense(value)
# `value` = [B, S, N, H]
value = self._value_dense(value)
attention_output, attention_scores = self._compute_attention(
query_tensor, key_tensor, value_tensor, attention_mask)
attention_output, attention_scores = self.compute_attention(
query, key, value, attention_mask)
attention_output = self._output_dense(attention_output)
if self._return_attention_scores:
......@@ -453,40 +466,42 @@ class CachedAttention(MultiHeadAttention):
Arguments are the same as `MultiHeadAttention` layer.
"""
def _update_cache(self, key_tensor, value_tensor, cache, decode_loop_step):
def _update_cache(self, key, value, cache, decode_loop_step):
"""Updates cache states and gets full-length key/value tensors."""
# Combines cached keys and values with new keys and values.
if decode_loop_step is not None:
# TPU special case.
key_seq_dim = cache["key"].shape.as_list()[1]
indices = tf.reshape(
tf.one_hot(decode_loop_step, key_seq_dim, dtype=key_tensor.dtype),
tf.one_hot(decode_loop_step, key_seq_dim, dtype=key.dtype),
[1, key_seq_dim, 1, 1])
key_tensor = cache["key"] + key_tensor * indices
key = cache["key"] + key * indices
value_seq_dim = cache["value"].shape.as_list()[1]
indices = tf.reshape(
tf.one_hot(decode_loop_step, value_seq_dim, dtype=value_tensor.dtype),
tf.one_hot(decode_loop_step, value_seq_dim, dtype=value.dtype),
[1, value_seq_dim, 1, 1])
value_tensor = cache["value"] + value_tensor * indices
value = cache["value"] + value * indices
else:
key_tensor = tf.concat(
[tf.cast(cache["key"], key_tensor.dtype), key_tensor], axis=1)
value_tensor = tf.concat(
[tf.cast(cache["value"], value_tensor.dtype), value_tensor], axis=1)
key = tf.concat([tf.cast(cache["key"], key.dtype), key], axis=1)
value = tf.concat([tf.cast(cache["value"], value.dtype), value], axis=1)
# Update cache
cache["key"] = key_tensor
cache["value"] = value_tensor
cache["key"] = key
cache["value"] = value
return key_tensor, value_tensor
return key, value
def call(self,
inputs,
query,
value,
key=None,
attention_mask=None,
cache=None,
decode_loop_step=None):
from_tensor = inputs[0]
to_tensor = inputs[1]
if not self._built_from_signature:
self._build_from_signature(query=query, value=value, key=key)
if key is None:
key = value
# Scalar dimensions referenced here:
# B = batch size (number of sequences)
......@@ -494,23 +509,21 @@ class CachedAttention(MultiHeadAttention):
# T = `to_tensor` sequence length
# N = `num_attention_heads`
# H = `size_per_head`
# `query_tensor` = [B, F, N ,H]
query_tensor = self._query_dense(from_tensor)
# `query` = [B, F, N ,H]
query = self._query_dense(query)
# `key_tensor` = [B, T, N, H]
key_tensor = self._key_dense(to_tensor)
# `key` = [B, T, N, H]
key = self._key_dense(key)
# `value_tensor` = [B, T, N, H]
value_tensor = self._value_dense(to_tensor)
# `value` = [B, T, N, H]
value = self._value_dense(value)
if cache:
key_tensor, value_tensor = self._update_cache(key_tensor, value_tensor,
cache, decode_loop_step)
key, value = self._update_cache(key, value, cache, decode_loop_step)
# Take the dot product between "query" and "key" to get the raw
# attention scores.
attention_scores = tf.einsum(self._dot_product_equation, key_tensor,
query_tensor)
attention_scores = tf.einsum(self._dot_product_equation, key, query)
attention_scores = tf.multiply(attention_scores,
1.0 / math.sqrt(float(self._key_size)))
......@@ -523,7 +536,7 @@ class CachedAttention(MultiHeadAttention):
attention_scores = self._dropout_layer(attention_scores)
# `context_layer` = [B, F, N, H]
attention_output = tf.einsum(self._combine_equation, attention_scores,
value_tensor)
value)
attention_output = self._output_dense(attention_output)
if self._return_attention_scores:
return attention_output, attention_scores, cache
......
......@@ -45,7 +45,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
value = tf.keras.Input(shape=(20, 80))
output = test_layer([query, value])
output = test_layer(query=query, value=value)
self.assertEqual(output.shape.as_list(), [None] + output_dims)
def test_non_masked_self_attention(self):
......@@ -53,7 +53,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
test_layer = attention.MultiHeadAttention(num_heads=12, key_size=64)
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
output = test_layer([query, query])
output = test_layer(query, query)
self.assertEqual(output.shape.as_list(), [None, 40, 80])
def test_attention_scores(self):
......@@ -62,7 +62,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
num_heads=12, key_size=64, return_attention_scores=True)
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
output, coef = test_layer([query, query])
output, coef = test_layer(query, query)
self.assertEqual(output.shape.as_list(), [None, 40, 80])
self.assertEqual(coef.shape.as_list(), [None, 12, 40, 40])
......@@ -76,7 +76,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
query = tf.keras.Input(shape=(4, 8))
value = tf.keras.Input(shape=(2, 8))
mask_tensor = tf.keras.Input(shape=(4, 2))
output = test_layer([query, value], mask_tensor)
output = test_layer(query=query, value=value, attention_mask=mask_tensor)
# Create a model containing the test layer.
model = tf.keras.Model([query, value, mask_tensor], output)
......@@ -100,7 +100,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
# Tests the layer with three inputs: Q, K, V.
key = tf.keras.Input(shape=(2, 8))
output = test_layer([query, value, key], mask_tensor)
output = test_layer(query, value=value, key=key, attention_mask=mask_tensor)
model = tf.keras.Model([query, value, key, mask_tensor], output)
masked_output_data = model.predict([from_data, to_data, to_data, mask_data])
......@@ -125,7 +125,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
output = test_layer([query, query])
output = test_layer(query, query)
self.assertEqual(output.shape.as_list(), [None, 40, 80])
@parameterized.named_parameters(
......@@ -147,11 +147,12 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
# Invoke the data with a random set of mask data. This should mask at least
# one element.
mask_data = np.random.randint(2, size=mask_shape).astype("bool")
output = test_layer([query, value], mask_data)
output = test_layer(query=query, value=value, attention_mask=mask_data)
# Invoke the same data, but with a null mask (where no elements are masked).
null_mask_data = np.ones(mask_shape)
unmasked_output = test_layer([query, value], null_mask_data)
unmasked_output = test_layer(
query=query, value=value, attention_mask=null_mask_data)
# Because one data is masked and one is not, the outputs should not be the
# same.
self.assertNotAllClose(output, unmasked_output)
......@@ -180,7 +181,7 @@ class AttentionSubclassTest(keras_parameterized.TestCase):
key_size=64)
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
output = test_layer([query, query])
output = test_layer(query, query)
self.assertEqual(output.shape.as_list(), [None, 40, 80])
......@@ -216,12 +217,14 @@ class CachedAttentionTest(keras_parameterized.TestCase):
# one element.
mask_data = np.random.randint(
2, size=(batch_size, from_seq_length, from_seq_length))
masked_output_data, cache = layer([from_data, from_data], mask_data, cache)
masked_output_data, cache = layer(
query=from_data, value=from_data, attention_mask=mask_data, cache=cache)
self.assertEqual(masked_output_data.shape, (3, 4, 8))
self.assertEqual(cache["value"].shape, (3, 4, 2, 2))
# Tests inputs without cache.
masked_output_data, cache = layer([from_data, from_data, mask_data])
masked_output_data, cache = layer(
query=from_data, value=from_data, attention_mask=mask_data)
self.assertEqual(masked_output_data.shape, (3, 4, 8))
self.assertIsNone(cache)
......@@ -243,9 +246,11 @@ class CachedAttentionTest(keras_parameterized.TestCase):
mask_data = np.random.randint(
2, size=(batch_size, from_seq_length, from_seq_length), dtype=np.int32)
# Testing the invocation directly as Keras cannot consume inputs correctly.
masked_output_data, cache = layer([from_data, from_data],
mask_data,
cache,
masked_output_data, cache = layer(
query=from_data,
value=from_data,
attention_mask=mask_data,
cache=cache,
decode_loop_step=decode_loop_step)
self.assertEqual(masked_output_data.shape, (3, 4, 8))
self.assertEqual(cache["value"].shape, (3, 4, 2, 2))
......
......@@ -21,6 +21,8 @@ from __future__ import print_function
import tensorflow as tf
from tensorflow.python.util import deprecation
_CHR_IDX = ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m"]
......@@ -57,6 +59,9 @@ class DenseEinsum(tf.keras.layers.Layer):
`(batch_size, units)`.
"""
@deprecation.deprecated(
None, "DenseEinsum is deprecated. Please use "
"tf.keras.experimental.EinsumDense layer instead.")
def __init__(self,
output_shape,
num_summed_dimensions=1,
......
......@@ -26,7 +26,6 @@ import math
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling.layers import attention
from official.nlp.modeling.layers import dense_einsum
from official.nlp.modeling.layers import masked_softmax
......@@ -67,28 +66,26 @@ class VotingAttention(tf.keras.layers.Layer):
self._bias_constraint = tf.keras.constraints.get(bias_constraint)
def build(self, unused_input_shapes):
self._query_dense = dense_einsum.DenseEinsum(
output_shape=(self._num_heads, self._head_size),
common_kwargs = dict(
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
dtype=self.dtype,
name="encdocatt_query")
self._key_dense = dense_einsum.DenseEinsum(
output_shape=(self._num_heads, self._head_size),
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
dtype=self.dtype,
name="encdocatt_key")
bias_constraint=self._bias_constraint)
self._query_dense = tf.keras.layers.experimental.EinsumDense(
"BAE,ENH->BANH",
output_shape=(None, self._num_heads, self._head_size),
bias_axes="NH",
name="query",
**common_kwargs)
self._key_dense = tf.keras.layers.experimental.EinsumDense(
"BAE,ENH->BANH",
output_shape=(None, self._num_heads, self._head_size),
bias_axes="NH",
name="key",
**common_kwargs)
super(VotingAttention, self).build(unused_input_shapes)
def call(self, encoder_outputs, doc_attention_mask):
......@@ -113,34 +110,52 @@ class VotingAttention(tf.keras.layers.Layer):
class MultiChannelAttention(attention.MultiHeadAttention):
"""Multi-channel Attention layer.
Introduced in: https://arxiv.org/abs/2001.09386. Expects multiple
cross-attention target sequences.
Introduced in, [Generating Representative Headlines for News Stories
](https://arxiv.org/abs/2001.09386). Expects multiple cross-attention
target sequences.
Call args:
query: Query `Tensor` of shape `[B, T, dim]`.
value: Value `Tensor` of shape `[B, A, S, dim]`, where A denotes the
context_attention_weights: Context weights of shape `[B, N, T, A]`, where N
is the number of attention heads. Combines multi-channel sources
context tensors according to the distribution among channels.
key: Optional key `Tensor` of shape `[B, A, S, dim]`. If not given, will use
`value` for both `key` and `value`, which is the most common case.
attention_mask: a boolean mask of shape `[B, T, S]`, that prevents attention
to certain positions.
"""
def _build_attention(self, qkv_rank):
super(MultiChannelAttention, self)._build_attention(qkv_rank)
def build_attention(self, rank):
super(MultiChannelAttention, self).build_attention(rank)
self._masked_softmax = masked_softmax.MaskedSoftmax(mask_expansion_axes=[2])
def call(self, inputs, attention_mask=None):
from_tensor = inputs[0]
to_tensor = inputs[1]
doc_attention_probs = inputs[2]
def call(self,
query,
value,
key=None,
context_attention_weights=None,
attention_mask=None):
if not self._built_from_signature:
self._build_from_signature(query, value, key=key)
if key is None:
key = value
# Scalar dimensions referenced here:
# B = batch size (number of stories)
# A = num_docs (number of docs)
# F = `from_tensor` sequence length
# T = `to_tensor` sequence length
# F = target sequence length
# T = source sequence length
# N = `num_attention_heads`
# H = `size_per_head`
# `query_tensor` = [B, F, N ,H]
query_tensor = self._query_dense(from_tensor)
query_tensor = self._query_dense(query)
# `key_tensor` = [B, A, T, N, H]
key_tensor = self._key_dense(to_tensor)
key_tensor = self._key_dense(key)
# `value_tensor` = [B, A, T, N, H]
value_tensor = self._value_dense(to_tensor)
value_tensor = self._value_dense(value)
# Take the dot product between "query" and "key" to get the raw
# attention scores.
......@@ -159,7 +174,7 @@ class MultiChannelAttention(attention.MultiHeadAttention):
# `context_layer` = [B, F, N, H]
context_layer = tf.einsum("BANFT,BATNH->BAFNH", attention_probs,
value_tensor)
attention_output = tf.einsum("BNFA,BAFNH->BFNH", doc_attention_probs,
attention_output = tf.einsum("BNFA,BAFNH->BFNH", context_attention_weights,
context_layer)
attention_output = self._output_dense(attention_output)
return attention_output
......@@ -48,7 +48,11 @@ class MultiChannelAttentionTest(tf.test.TestCase):
mask_data = np.random.randint(2, size=(3, num_docs, 4, 2))
doc_probs = np.random.randint(
2, size=(3, num_heads, 4, num_docs)).astype(float)
outputs = attention_layer([from_data, to_data, doc_probs], mask_data)
outputs = attention_layer(
query=from_data,
value=to_data,
context_attention_weights=doc_probs,
attention_mask=mask_data)
self.assertEqual(outputs.shape, (3, 4, 8))
......
......@@ -160,7 +160,6 @@ class RelativePositionEmbedding(tf.keras.layers.Layer):
"hidden_size": self._hidden_size,
"min_timescale": self._min_timescale,
"max_timescale": self._max_timescale,
"length": self._length,
}
base_config = super(RelativePositionEmbedding, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
......
......@@ -23,7 +23,6 @@ import gin
import tensorflow as tf
from official.nlp.modeling.layers import attention
from official.nlp.modeling.layers import dense_einsum
@tf.keras.utils.register_keras_serializable(package="Text")
......@@ -109,19 +108,20 @@ class ReZeroTransformer(tf.keras.layers.Layer):
"The input size (%d) is not a multiple of the number of attention "
"heads (%d)" % (hidden_size, self._num_heads))
self._attention_head_size = int(hidden_size // self._num_heads)
self._attention_layer = attention.MultiHeadAttention(
num_heads=self._num_heads,
key_size=self._attention_head_size,
dropout=self._attention_dropout_rate,
common_kwargs = dict(
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="self_attention")
bias_constraint=self._bias_constraint)
self._attention_layer = attention.MultiHeadAttention(
num_heads=self._num_heads,
key_size=self._attention_head_size,
dropout=self._attention_dropout_rate,
name="self_attention",
**common_kwargs)
self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
if self._use_layer_norm:
# Use float32 in layernorm for numeric stability.
......@@ -132,17 +132,12 @@ class ReZeroTransformer(tf.keras.layers.Layer):
axis=-1,
epsilon=1e-12,
dtype=tf.float32))
self._intermediate_dense = dense_einsum.DenseEinsum(
output_shape=self._intermediate_size,
activation=None,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="intermediate")
self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
"abc,cd->abd",
output_shape=(None, self._intermediate_size),
bias_axes="d",
name="intermediate",
**common_kwargs)
policy = tf.keras.mixed_precision.experimental.global_policy()
if policy.name == "mixed_bfloat16":
# bfloat16 causes BERT with the LAMB optimizer to not converge
......@@ -151,16 +146,12 @@ class ReZeroTransformer(tf.keras.layers.Layer):
policy = tf.float32
self._intermediate_activation_layer = tf.keras.layers.Activation(
self._intermediate_activation, dtype=policy)
self._output_dense = dense_einsum.DenseEinsum(
output_shape=hidden_size,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="output")
self._output_dense = tf.keras.layers.experimental.EinsumDense(
"abc,cd->abd",
output_shape=(None, hidden_size),
bias_axes="d",
name="output",
**common_kwargs)
self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
if self._use_layer_norm:
# Use float32 in layernorm for numeric stability.
......@@ -222,9 +213,9 @@ class ReZeroTransformer(tf.keras.layers.Layer):
attention_mask = attention_mask[:, 0:self._output_range, :]
else:
target_tensor = input_tensor
attention_inputs = [target_tensor, input_tensor]
attention_output = self._attention_layer(attention_inputs, attention_mask)
attention_output = self._attention_layer(
query=target_tensor, value=input_tensor, attention_mask=attention_mask)
attention_output = self._attention_dropout(attention_output)
attention_output = target_tensor + self._rezero_a * attention_output
if self._use_layer_norm:
......
......@@ -58,7 +58,7 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
bias_constraint: Constraint for dense layer kernels.
"""
def _build_attention(self, qkv_rank):
def build_attention(self, qkv_rank):
"""Builds multi-head dot-product attention computations.
This function overrides base class to create additional linear projection
......@@ -67,7 +67,7 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
Args:
qkv_rank: the rank of query, key, value tensors after projection.
"""
super(TalkingHeadsAttention, self)._build_attention(qkv_rank)
super(TalkingHeadsAttention, self).build_attention(qkv_rank)
# Build an equation:
# (<batch_dims>, num_heads_a, ...),(num_heads_a, num_heads_b) ->
......@@ -103,7 +103,7 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
dtype=self.dtype,
trainable=True)
def _compute_attention(self,
def compute_attention(self,
query_tensor,
key_tensor,
value_tensor,
......
......@@ -46,7 +46,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
value = tf.keras.Input(shape=(20, 80))
output = test_layer([query, value])
output = test_layer(query=query, value=value)
self.assertEqual(output.shape.as_list(), [None] + output_dims)
def test_non_masked_self_attention(self):
......@@ -55,7 +55,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
num_heads=12, key_size=64)
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
output = test_layer([query, query])
output = test_layer(query=query, value=query)
self.assertEqual(output.shape.as_list(), [None, 40, 80])
def test_attention_scores(self):
......@@ -64,7 +64,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
num_heads=12, key_size=64, return_attention_scores=True)
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
output, coef = test_layer([query, query])
output, coef = test_layer(query=query, value=query)
self.assertEqual(output.shape.as_list(), [None, 40, 80])
self.assertEqual(coef.shape.as_list(), [None, 12, 40, 40])
......@@ -78,7 +78,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
query = tf.keras.Input(shape=(4, 8))
value = tf.keras.Input(shape=(2, 8))
mask_tensor = tf.keras.Input(shape=(4, 2))
output = test_layer([query, value], mask_tensor)
output = test_layer(query=query, value=value, attention_mask=mask_tensor)
# Create a model containing the test layer.
model = tf.keras.Model([query, value, mask_tensor], output)
......@@ -102,7 +102,8 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
# Tests the layer with three inputs: Q, K, V.
key = tf.keras.Input(shape=(2, 8))
output = test_layer([query, value, key], mask_tensor)
output = test_layer(
query=query, value=value, key=key, attention_mask=mask_tensor)
model = tf.keras.Model([query, value, key, mask_tensor], output)
masked_output_data = model.predict([from_data, to_data, to_data, mask_data])
......@@ -127,7 +128,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
# Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80))
output = test_layer([query, query])
output = test_layer(query=query, value=query)
self.assertEqual(output.shape.as_list(), [None, 40, 80])
@parameterized.named_parameters(
......@@ -149,11 +150,12 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
# Invoke the data with a random set of mask data. This should mask at least
# one element.
mask_data = np.random.randint(2, size=mask_shape).astype("bool")
output = test_layer([query, value], mask_data)
output = test_layer(query=query, value=value, attention_mask=mask_data)
# Invoke the same data, but with a null mask (where no elements are masked).
null_mask_data = np.ones(mask_shape)
unmasked_output = test_layer([query, value], null_mask_data)
unmasked_output = test_layer(
query=query, value=value, attention_mask=null_mask_data)
# Because one data is masked and one is not, the outputs should not be the
# same.
self.assertNotAllClose(output, unmasked_output)
......
......@@ -23,7 +23,6 @@ import gin
import tensorflow as tf
from official.nlp.modeling.layers import attention
from official.nlp.modeling.layers import dense_einsum
from official.nlp.modeling.layers import multi_channel_attention
from official.nlp.modeling.layers.util import tf_function_if_eager
......@@ -106,21 +105,24 @@ class Transformer(tf.keras.layers.Layer):
"The input size (%d) is not a multiple of the number of attention "
"heads (%d)" % (hidden_size, self._num_heads))
self._attention_head_size = int(hidden_size // self._num_heads)
self._attention_layer = attention.MultiHeadAttention(
num_heads=self._num_heads,
key_size=self._attention_head_size,
dropout=self._attention_dropout_rate,
common_kwargs = dict(
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="self_attention")
bias_constraint=self._bias_constraint)
self._attention_layer = attention.MultiHeadAttention(
num_heads=self._num_heads,
key_size=self._attention_head_size,
dropout=self._attention_dropout_rate,
name="self_attention",
**common_kwargs)
# pylint: disable=protected-access
self._attention_layer.build([input_tensor_shape] * 3)
# Temporarily handling for checkpoint compatible changes.
self._attention_layer._build_from_signature(
query=input_tensor_shape, value=input_tensor_shape)
self._attention_output_dense = self._attention_layer._output_dense
# pylint: enable=protected-access
self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
......@@ -132,17 +134,12 @@ class Transformer(tf.keras.layers.Layer):
axis=-1,
epsilon=1e-12,
dtype=tf.float32))
self._intermediate_dense = dense_einsum.DenseEinsum(
output_shape=self._intermediate_size,
activation=None,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="intermediate")
self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
"abc,cd->abd",
output_shape=(None, self._intermediate_size),
bias_axes="d",
name="intermediate",
**common_kwargs)
policy = tf.keras.mixed_precision.experimental.global_policy()
if policy.name == "mixed_bfloat16":
# bfloat16 causes BERT with the LAMB optimizer to not converge
......@@ -151,16 +148,12 @@ class Transformer(tf.keras.layers.Layer):
policy = tf.float32
self._intermediate_activation_layer = tf.keras.layers.Activation(
self._intermediate_activation, dtype=policy)
self._output_dense = dense_einsum.DenseEinsum(
output_shape=hidden_size,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="output")
self._output_dense = tf.keras.layers.experimental.EinsumDense(
"abc,cd->abd",
output_shape=(None, hidden_size),
bias_axes="d",
name="output",
**common_kwargs)
self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
# Use float32 in layernorm for numeric stability.
self._output_layer_norm = tf.keras.layers.LayerNormalization(
......@@ -211,9 +204,9 @@ class Transformer(tf.keras.layers.Layer):
attention_mask = attention_mask[:, 0:self._output_range, :]
else:
target_tensor = input_tensor
attention_inputs = [target_tensor, input_tensor]
attention_output = self._attention_layer(attention_inputs, attention_mask)
attention_output = self._attention_layer(
query=target_tensor, value=input_tensor, attention_mask=attention_mask)
attention_output = self._attention_dropout(attention_output)
attention_output = self._attention_layer_norm(target_tensor +
attention_output)
......@@ -312,30 +305,27 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (hidden_size, self.num_attention_heads))
self.attention_head_size = int(hidden_size / self.num_attention_heads)
# Self attention.
self.self_attention = attention.CachedAttention(
num_heads=self.num_attention_heads,
key_size=self.attention_head_size,
dropout=self.attention_dropout_rate,
common_kwargs = dict(
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="self_attention")
self.self_attention_output_dense = dense_einsum.DenseEinsum(
output_shape=hidden_size,
num_summed_dimensions=2,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="self_attention_output")
bias_constraint=self._bias_constraint)
# Self attention.
self.self_attention = attention.CachedAttention(
num_heads=self.num_attention_heads,
key_size=self.attention_head_size,
dropout=self.attention_dropout_rate,
name="self_attention",
**common_kwargs)
self.self_attention_output_dense = tf.keras.layers.experimental.EinsumDense(
"abc,cd->abd",
output_shape=(None, hidden_size),
bias_axes="d",
name="output",
**common_kwargs)
self.self_attention_dropout = tf.keras.layers.Dropout(
rate=self.dropout_rate)
self.self_attention_layer_norm = (
......@@ -347,14 +337,8 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
key_size=self.attention_head_size,
dropout=self.attention_dropout_rate,
output_shape=hidden_size,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="attention/encdec")
name="attention/encdec",
**common_kwargs)
self.encdec_attention_dropout = tf.keras.layers.Dropout(
rate=self.dropout_rate)
......@@ -363,29 +347,20 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
name="attention/encdec_output_layer_norm", axis=-1, epsilon=1e-12))
# Feed-forward projection.
self.intermediate_dense = dense_einsum.DenseEinsum(
output_shape=self.intermediate_size,
activation=None,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="intermediate")
self.intermediate_dense = tf.keras.layers.experimental.EinsumDense(
"abc,cd->abd",
output_shape=(None, self.intermediate_size),
bias_axes="d",
name="intermediate",
**common_kwargs)
self.intermediate_activation_layer = tf.keras.layers.Activation(
self.intermediate_activation)
self.output_dense = dense_einsum.DenseEinsum(
output_shape=hidden_size,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="output")
self.output_dense = tf.keras.layers.experimental.EinsumDense(
"abc,cd->abd",
output_shape=(None, hidden_size),
bias_axes="d",
name="output",
**common_kwargs)
self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
self.output_layer_norm = tf.keras.layers.LayerNormalization(
name="output_layer_norm", axis=-1, epsilon=1e-12)
......@@ -409,21 +384,23 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
"TransformerDecoderLayer must have 4 inputs, but it got: %d" %
len(inputs))
input_tensor, memory, attention_mask, self_attention_mask = inputs[:4]
self_attention_inputs = [input_tensor, input_tensor]
self_attention_output, cache = self.self_attention(
self_attention_inputs,
query=input_tensor,
value=input_tensor,
attention_mask=self_attention_mask,
cache=cache,
decode_loop_step=decode_loop_step)
self_attention_output = self.self_attention_dropout(self_attention_output)
self_attention_output = self.self_attention_layer_norm(
input_tensor + self_attention_output)
cross_attn_inputs = [self_attention_output, memory]
cross_attn_inputs = dict(
query=self_attention_output,
value=memory,
attention_mask=attention_mask)
if self.multi_channel_cross_attention:
# Accesses the 5-th input tensor for the doc-attention probabilities.
cross_attn_inputs.append(inputs[-1])
attention_output = self.encdec_attention(cross_attn_inputs, attention_mask)
cross_attn_inputs["context_attention_weights"] = inputs[-1]
attention_output = self.encdec_attention(**cross_attn_inputs)
attention_output = self.encdec_attention_dropout(attention_output)
attention_output = self.encdec_attention_layer_norm(self_attention_output +
attention_output)
......
......@@ -262,9 +262,8 @@ class TransformerScaffold(tf.keras.layers.Layer):
else:
input_tensor, attention_mask = (inputs, None)
attention_inputs = [input_tensor, input_tensor]
attention_output = self._attention_layer(attention_inputs, attention_mask)
attention_output = self._attention_layer(
query=input_tensor, value=input_tensor, attention_mask=attention_mask)
attention_output = self._attention_dropout(attention_output)
attention_output = self._attention_layer_norm(input_tensor +
attention_output)
......
......@@ -39,10 +39,10 @@ class ValidatedAttentionLayer(attention.MultiHeadAttention):
super(ValidatedAttentionLayer, self).__init__(**kwargs)
self.list = call_list
def call(self, inputs, attention_mask=None):
def call(self, query, value, attention_mask=None):
self.list.append(True)
return super(ValidatedAttentionLayer, self).call(
inputs, attention_mask=attention_mask)
query, value, attention_mask=attention_mask)
def get_config(self):
config = super(ValidatedAttentionLayer, self).get_config()
......
......@@ -152,7 +152,10 @@ class TransformerLayerTest(keras_parameterized.TestCase):
_ = new_layer([input_data, mask_data])
new_layer.set_weights(test_layer.get_weights())
new_output_tensor = new_layer([input_data, mask_data])
self.assertAllClose(new_output_tensor, output_tensor[:, 0:1, :])
self.assertAllClose(new_output_tensor,
output_tensor[:, 0:1, :],
atol=5e-5,
rtol=0.003)
def test_layer_invocation_with_float16_dtype(self, transformer_cls):
tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
......
......@@ -4,6 +4,3 @@ Losses contains common loss computation used in NLP tasks.
* `weighted_sparse_categorical_crossentropy_loss` computes per-batch sparse
categorical crossentropy loss.
* `weighted_sparse_categorical_crossentropy_per_example_loss` computes
per-example sparse categorical crossentropy loss.
......@@ -14,4 +14,3 @@
# ==============================================================================
"""Activations package definition. Subject to change."""
from official.nlp.modeling.losses.weighted_sparse_categorical_crossentropy import loss as weighted_sparse_categorical_crossentropy_loss
from official.nlp.modeling.losses.weighted_sparse_categorical_crossentropy import per_example_loss as weighted_sparse_categorical_crossentropy_per_example_loss
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Sparse categorical cross-entropy losses."""
"""Weighted sparse categorical cross-entropy losses."""
from __future__ import absolute_import
from __future__ import division
......@@ -43,37 +43,7 @@ def _validate_rank(labels, predictions, weights):
"predictions.shape was %s.") % (labels.shape, predictions.shape))
def per_example_loss(labels, predictions, weights=None):
"""Calculate a per-example sparse categorical crossentropy loss.
This loss function assumes that the predictions are post-softmax.
Args:
labels: The labels to evaluate against. Should be a set of integer indices
ranging from 0 to (vocab_size-1).
predictions: The network predictions. Should have softmax already applied.
weights: An optional weight array of the same shape as the 'labels' array.
If None, all examples will be used.
Returns:
A tensor of shape predictions.shape[:-1] containing the per-example
loss.
"""
# When using these functions with the Keras core API, we will need to squeeze
# the labels tensor - Keras adds a spurious inner dimension.
labels, predictions = _adjust_labels(labels, predictions)
_validate_rank(labels, predictions, weights)
labels_one_hot = tf.one_hot(labels, predictions.shape[-1])
labels_one_hot = tf.cast(labels_one_hot, predictions.dtype)
per_example_loss_data = -tf.reduce_sum(
predictions * labels_one_hot, axis=[-1])
if weights is not None:
weights = tf.cast(weights, per_example_loss_data.dtype)
per_example_loss_data = weights * per_example_loss_data
return per_example_loss_data
def loss(labels, predictions, weights=None):
def loss(labels, predictions, weights=None, from_logits=False):
"""Calculate a per-batch sparse categorical crossentropy loss.
This loss function assumes that the predictions are post-softmax.
......@@ -83,6 +53,7 @@ def loss(labels, predictions, weights=None):
predictions: The network predictions. Should have softmax already applied.
weights: An optional weight array of the same shape as the 'labels' array.
If None, all examples will be used.
from_logits: Whether the input predictions are logits.
Returns:
A loss scalar.
......@@ -95,12 +66,11 @@ def loss(labels, predictions, weights=None):
labels, predictions = _adjust_labels(labels, predictions)
_validate_rank(labels, predictions, weights)
per_example_loss_data = per_example_loss(labels, predictions, weights)
example_losses = tf.keras.losses.sparse_categorical_crossentropy(
labels, predictions, from_logits=from_logits)
if weights is None:
return tf.reduce_mean(per_example_loss_data)
else:
numerator = tf.reduce_sum(per_example_loss_data)
return tf.reduce_mean(example_losses)
weights = tf.cast(weights, predictions.dtype)
denominator = tf.reduce_sum(weights) + 1e-5
return numerator / denominator
return tf.math.divide_no_nan(
tf.reduce_sum(example_losses * weights), tf.reduce_sum(weights))
......@@ -53,8 +53,7 @@ class ClassificationLossTest(keras_parameterized.TestCase):
# Create a maskedLM from the transformer stack.
test_layer = layers.MaskedLM(
embedding_table=xformer_stack.get_embedding_table(),
output=output)
embedding_table=xformer_stack.get_embedding_table(), output=output)
# Create a model from the masked LM layer.
lm_input_tensor = tf.keras.Input(shape=(sequence_length, hidden_size))
......@@ -63,123 +62,6 @@ class ClassificationLossTest(keras_parameterized.TestCase):
output = test_layer(lm_input_tensor, masked_positions=masked_lm_positions)
return tf.keras.Model([lm_input_tensor, masked_lm_positions], output)
def create_classification_model(self, input_width, num_classes):
test_object = networks.Classification(
input_width=input_width, num_classes=num_classes)
# Create a 2-dimensional input (the first dimension is implicit).
pooled_data = tf.keras.Input(shape=(input_width,), dtype=tf.float32)
output = test_object(pooled_data)
return tf.keras.Model(pooled_data, output)
def test_per_example_loss_3d_input(self):
"""Test per-example loss with a 3-dimensional input, from a masked LM."""
vocab_size = 100
sequence_length = 32
hidden_size = 64
num_predictions = 21
model = self.create_lm_model(
vocab_size=vocab_size,
sequence_length=sequence_length,
hidden_size=hidden_size,
num_predictions=num_predictions)
# Get the output of the masked LM.
batch_size = 3
lm_input_data = 10 * np.random.random_sample(
(batch_size, sequence_length, hidden_size))
masked_position_data = np.random.randint(
2, size=(batch_size, num_predictions))
output_data = model.predict([lm_input_data, masked_position_data])
# Calculate per-example loss.
labels = np.random.randint(vocab_size, size=(batch_size, num_predictions))
per_example_loss_data = weighted_sparse_categorical_crossentropy.per_example_loss(
predictions=output_data, labels=labels)
# Per-example loss data should have one value per prediction, and those
# values shouldn't be zero in this case (as we're using random data).
expected_shape = [batch_size, num_predictions]
self.assertEqual(expected_shape, per_example_loss_data.shape.as_list())
self.assertNotAllClose(
tf.zeros_like(per_example_loss_data), per_example_loss_data)
def test_per_example_loss_2d_input(self):
"""Test per-example loss with a 2-d input, from a classifier."""
input_width = 512
num_classes = 10
model = self.create_classification_model(input_width, num_classes)
# Invoke the network as part of a Model.
batch_size = 3
input_data = 10 * np.random.random_sample((batch_size, input_width))
output_data = model.predict(input_data)
# Calculate per example loss.
labels = np.random.randint(num_classes, size=(batch_size))
per_example_loss_data = weighted_sparse_categorical_crossentropy.per_example_loss(
predictions=output_data, labels=labels)
# Per-example loss data should have one value per batch item, and those
# values shouldn't be zero in this case (as we're using random data).
self.assertEqual([batch_size], per_example_loss_data.shape.as_list())
self.assertNotAllClose(
tf.zeros_like(per_example_loss_data), per_example_loss_data)
def test_per_example_loss_weights_3d_input(self):
"""Test weighted per-example loss with a 3-d input, from a masked LM."""
vocab_size = 100
sequence_length = 32
hidden_size = 64
num_predictions = 21
model = self.create_lm_model(
vocab_size=vocab_size,
sequence_length=sequence_length,
hidden_size=hidden_size,
num_predictions=num_predictions)
# Get the output of the masked LM.
batch_size = 3
lm_input_data = 10 * np.random.random_sample(
(batch_size, sequence_length, hidden_size))
masked_position_data = np.random.randint(
2, size=(batch_size, num_predictions))
output_data = model.predict([lm_input_data, masked_position_data])
# Calculate per-example loss with weights.
labels = np.random.randint(vocab_size, size=(batch_size, num_predictions))
weights = np.random.randint(2, size=(batch_size, num_predictions))
per_example_loss_data = weighted_sparse_categorical_crossentropy.per_example_loss(
predictions=output_data, labels=labels, weights=weights)
# Weighted per-example loss data should be equivalent to multiplying the
# loss tensor by the weights tensor.
expected_weighted_loss = per_example_loss_data * weights
self.assertAllClose(expected_weighted_loss, per_example_loss_data)
def test_per_example_loss_weights_2d_input(self):
"""Test weighted per-example loss with a 2-d input, from a classifier."""
input_width = 512
num_classes = 10
model = self.create_classification_model(input_width, num_classes)
# Invoke the network as part of a Model.
batch_size = 3
input_data = 10 * np.random.random_sample((batch_size, input_width))
output_data = model.predict(input_data)
# Calculate per-example loss with weights.
labels = np.random.randint(num_classes, size=(batch_size))
weights = np.random.randint(2, size=(batch_size))
per_example_loss_data = weighted_sparse_categorical_crossentropy.per_example_loss(
predictions=output_data, labels=labels, weights=weights)
# Weighted per-example loss data should be equivalent to multiplying the
# loss tensor by the weights tensor.
expected_weighted_loss = per_example_loss_data * weights
self.assertAllClose(expected_weighted_loss, per_example_loss_data)
def test_loss_3d_input(self):
"""Test overall loss with a 3-dimensional input, from a masked LM."""
vocab_size = 100
......@@ -213,26 +95,6 @@ class ClassificationLossTest(keras_parameterized.TestCase):
self.assertNotAllClose(
tf.zeros_like(per_example_loss_data), per_example_loss_data)
def test_loss_2d_input(self):
"""Test overall loss with a 2-d input, from a classifier."""
input_width = 512
num_classes = 10
model = self.create_classification_model(input_width, num_classes)
# Invoke the network as part of a Model.
batch_size = 3
input_data = 10 * np.random.random_sample((batch_size, input_width))
output_data = model.predict(input_data)
# Calculate per example loss.
labels = np.random.randint(num_classes, size=(batch_size))
loss_data = weighted_sparse_categorical_crossentropy.loss(
predictions=output_data, labels=labels)
# Loss data should have one value only, and that value shouldn't be zero in
# this case (as we're using random data).
self.assertNotAllClose(0, loss_data)
def test_loss_weights_3d_input(self):
"""Test masked loss with a 3-dimensional input, from a masked LM."""
vocab_size = 100
......@@ -262,26 +124,6 @@ class ClassificationLossTest(keras_parameterized.TestCase):
# Because the tensor is fully masked, the loss should be 0.
self.assertAllClose(0, weighted_loss_data)
def test_loss_weights_2d_input(self):
"""Test masked loss with a 2-d input, from a classifier."""
input_width = 512
num_classes = 10
model = self.create_classification_model(input_width, num_classes)
# Invoke the network as part of a Model.
batch_size = 3
input_data = 10 * np.random.random_sample((batch_size, input_width))
output_data = model.predict(input_data)
# Calculate a fully masked weight tensor. This should give a loss of zero.
labels = np.random.randint(num_classes, size=(batch_size))
null_weights = np.zeros((batch_size))
weighted_loss_data = weighted_sparse_categorical_crossentropy.loss(
predictions=output_data, labels=labels, weights=null_weights)
# Because the tensor is fully masked, the loss should be 0.
self.assertAllClose(0, weighted_loss_data)
def test_mismatched_predictions_and_labels_ranks_squeezes(self):
"""Test that the loss asserts when rank(predictions)-1 != rank(labels)."""
batch_size = 3
......@@ -289,7 +131,7 @@ class ClassificationLossTest(keras_parameterized.TestCase):
labels = np.random.randint(10, size=(batch_size, 1))
# All that this test tests is that the squeeze is successful.
_ = weighted_sparse_categorical_crossentropy.per_example_loss(
_ = weighted_sparse_categorical_crossentropy.loss(
predictions=output_data, labels=labels)
def test_mismatched_weights_and_labels_ranks_fail(self):
......@@ -299,9 +141,6 @@ class ClassificationLossTest(keras_parameterized.TestCase):
labels = np.random.randint(10, size=(batch_size, 10))
weights = np.random.randint(2, size=(batch_size))
with self.assertRaisesRegex(RuntimeError, ".*of the same rank.*"):
_ = weighted_sparse_categorical_crossentropy.per_example_loss(
predictions=output_data, labels=labels, weights=weights)
with self.assertRaisesRegex(RuntimeError, ".*of the same rank.*"):
_ = weighted_sparse_categorical_crossentropy.loss(
predictions=output_data, labels=labels, weights=weights)
......@@ -317,8 +156,6 @@ class ClassificationLossTest(keras_parameterized.TestCase):
# We're not trying to validate numerical correctness, just ensure that
# we can in fact pass tensors to these functions without causing runtime
# errors from the shape checking code.
_ = weighted_sparse_categorical_crossentropy.per_example_loss(
predictions=output_data, labels=labels, weights=weights)
_ = weighted_sparse_categorical_crossentropy.loss(
predictions=output_data, labels=labels, weights=weights)
......@@ -338,20 +175,15 @@ class ClassificationLossTest(keras_parameterized.TestCase):
[-2.7760355, -1.8219438, -3.0924666, -1.0779881, -0.9407509]]])
labels = np.array([[4, 0], [2, 2], [2, 1]])
# Validate that per_example loss calculations are the same.
per_example_loss_data = weighted_sparse_categorical_crossentropy.per_example_loss(
predictions=output_data, labels=labels)
expected_per_example_loss_data = [[1.2923571, 2.7117882],
[2.287932, 2.287932],
[3.0924666, 1.8219438]]
self.assertAllClose(expected_per_example_loss_data, per_example_loss_data)
# Validate that overall loss calculations are the same.
weights = np.array([[1, 0], [0, 0], [0, 0]])
loss_data = weighted_sparse_categorical_crossentropy.loss(
predictions=output_data, labels=labels, weights=weights)
predictions=output_data,
labels=labels,
weights=weights,
from_logits=True)
expected_loss_data = 1.2923441
self.assertAllClose(expected_loss_data, loss_data)
self.assertAllClose(expected_loss_data, loss_data, rtol=1e-3)
def test_legacy_classification_loss_compatibility(self):
"""Test to validate computational correctness during refactors."""
......@@ -362,19 +194,15 @@ class ClassificationLossTest(keras_parameterized.TestCase):
[-1.6975292e-03, -6.4009643e+00, -1.0226612e+01]])
labels = np.array([2, 1])
# Validate that per_example loss calculations are the same.
per_example_loss_data = weighted_sparse_categorical_crossentropy.per_example_loss(
predictions=output_data, labels=labels)
expected_per_example_loss_data = [6.4434357, 6.4009643]
self.assertAllClose(expected_per_example_loss_data, per_example_loss_data)
# Validate that overall loss calculations are the same.
weights = None
loss_data = weighted_sparse_categorical_crossentropy.loss(
predictions=output_data, labels=labels, weights=weights)
predictions=output_data,
labels=labels,
weights=weights,
from_logits=True)
expected_loss_data = 6.4222
self.assertAllClose(expected_loss_data, loss_data)
self.assertAllClose(expected_loss_data, loss_data, rtol=1e-3)
if __name__ == "__main__":
tf.test.main()
......@@ -17,3 +17,4 @@ from official.nlp.modeling.models.bert_classifier import BertClassifier
from official.nlp.modeling.models.bert_pretrainer import BertPretrainer
from official.nlp.modeling.models.bert_span_labeler import BertSpanLabeler
from official.nlp.modeling.models.bert_token_classifier import BertTokenClassifier
from official.nlp.modeling.models.electra_pretrainer import ElectraPretrainer
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment