Commit 31ca3b97 authored by Kaushik Shivakumar's avatar Kaushik Shivakumar

resolve merge conflicts

parents 3e9d886d 7fcd7cba
# NLP Modeling Library

This library provides a set of Keras primitives (Layers, Networks, and Models)
that can be assembled into transformer-based models. They are
flexible, validated, interoperable, and both TF1 and TF2 compatible.

@@ -16,6 +16,11 @@ standardized configuration.
* [`losses`](losses) contains common loss computation used in NLP tasks.
Please see the colab
[nlp_modeling_library_intro.ipynb](https://colab.sandbox.google.com/github/tensorflow/models/blob/master/official/colab/nlp/nlp_modeling_library_intro.ipynb)
for how to build transformer-based NLP models using the above primitives.
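
For a quick taste of the keyword-argument call convention used by these
primitives (a minimal sketch; it assumes `MultiHeadAttention` is exported via
`official.nlp.modeling.layers`, as with the attention layer shown later in this
change):

```python
import tensorflow as tf
from official.nlp.modeling import layers

# Cross-attention between a 40-step query sequence and a 20-step value
# sequence; the layer is called with keyword arguments rather than a list.
attention = layers.MultiHeadAttention(num_heads=12, key_size=64)
query = tf.keras.Input(shape=(40, 80))
value = tf.keras.Input(shape=(20, 80))
output = attention(query=query, value=value)  # -> [None, 40, 80]
```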
Besides the pre-defined primitives, it also provides scaffold classes to allow
easy experimentation with novel architectures, e.g., you don’t need to fork a
whole Transformer object to try a different kind of attention primitive.
@@ -33,11 +38,9 @@ embedding subnetwork (which will replace the standard embedding logic) and/or a
custom hidden layer (which will replace the Transformer instantiation in the
encoder).
Please see the colab
[customize_encoder.ipynb](https://colab.sandbox.google.com/github/tensorflow/models/blob/master/official/colab/nlp/customize_encoder.ipynb)
for how to use scaffold classes to build novel architectures.

BERT and ALBERT models in this repo are implemented using this library. Code examples can be found in the corresponding model folder.
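
As a sketch of the scaffold-style flexibility described above (assuming
`MultiHeadAttention` is exported via `official.nlp.modeling.layers`;
`TalkingHeadsAttention` in this change follows the same pattern), a custom
attention variant only needs to override the attention hooks rather than fork
the whole Transformer:

```python
from official.nlp.modeling import layers

class ScaledDownAttention(layers.MultiHeadAttention):
  """Hypothetical variant that post-processes the default attention output."""

  def compute_attention(self, query, key, value, attention_mask=None):
    # Reuse the default dot-product attention, then rescale its output.
    output, scores = super(ScaledDownAttention, self).compute_attention(
        query, key, value, attention_mask=attention_mask)
    return 0.5 * output, scores
```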
@@ -3,11 +3,6 @@
Layers are the fundamental building blocks for NLP models. They can be used to
assemble new layers, networks, or models.
* [DenseEinsum](dense_einsum.py) implements a feedforward network using
tf.einsum. This layer contains the einsum op, the associated weight, and the
logic required to generate the einsum expression for the given
initialization parameters.
* [MultiHeadAttention](attention.py) implements an optionally masked attention
between query, key, value tensors as described in
["Attention Is All You Need"](https://arxiv.org/abs/1706.03762). If
......
...@@ -33,7 +33,7 @@ EinsumDense = tf.keras.layers.experimental.EinsumDense ...@@ -33,7 +33,7 @@ EinsumDense = tf.keras.layers.experimental.EinsumDense
_CHR_IDX = string.ascii_lowercase _CHR_IDX = string.ascii_lowercase
def _build_attention_equation(qkv_rank, attn_axes): def _build_attention_equation(rank, attn_axes):
"""Builds einsum equations for the attention computation. """Builds einsum equations for the attention computation.
Query, key, value inputs after projection are expected to have the shape as: Query, key, value inputs after projection are expected to have the shape as:
...@@ -50,19 +50,19 @@ def _build_attention_equation(qkv_rank, attn_axes): ...@@ -50,19 +50,19 @@ def _build_attention_equation(qkv_rank, attn_axes):
<query attention dims>, num_heads, channels) <query attention dims>, num_heads, channels)
Args: Args:
qkv_rank: the rank of query, key, value tensors. rank: the rank of query, key, value tensors.
attn_axes: a list/tuple of axes, [1, rank), that will do attention. attn_axes: a list/tuple of axes, [1, rank), that will do attention.
Returns: Returns:
Einsum equations. Einsum equations.
""" """
target_notation = _CHR_IDX[:qkv_rank] target_notation = _CHR_IDX[:rank]
# `batch_dims` includes the head dim. # `batch_dims` includes the head dim.
batch_dims = tuple(np.delete(range(qkv_rank), attn_axes + (qkv_rank - 1,))) batch_dims = tuple(np.delete(range(rank), attn_axes + (rank - 1,)))
letter_offset = qkv_rank letter_offset = rank
source_notation = "" source_notation = ""
for i in range(qkv_rank): for i in range(rank):
if i in batch_dims or i == qkv_rank - 1: if i in batch_dims or i == rank - 1:
source_notation += target_notation[i] source_notation += target_notation[i]
else: else:
source_notation += _CHR_IDX[letter_offset] source_notation += _CHR_IDX[letter_offset]
...@@ -167,8 +167,8 @@ class MultiHeadAttention(tf.keras.layers.Layer): ...@@ -167,8 +167,8 @@ class MultiHeadAttention(tf.keras.layers.Layer):
sequence dims. If not specified, projects back to the key feature dim. sequence dims. If not specified, projects back to the key feature dim.
attention_axes: axes over which the attention is applied. `None` means attention_axes: axes over which the attention is applied. `None` means
attention over all axes, but batch, heads, and features. attention over all axes, but batch, heads, and features.
return_attention_scores: bool, if `True`, returns the multi-head return_attention_scores: bool, if `True`, returns the multi-head attention
attention scores as an additional output argument. scores as an additional output argument.
kernel_initializer: Initializer for dense layer kernels. kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases. bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels. kernel_regularizer: Regularizer for dense layer kernels.
...@@ -176,6 +176,13 @@ class MultiHeadAttention(tf.keras.layers.Layer): ...@@ -176,6 +176,13 @@ class MultiHeadAttention(tf.keras.layers.Layer):
activity_regularizer: Regularizer for dense layer activity. activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels. kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer kernels. bias_constraint: Constraint for dense layer kernels.
Call args:
query: Query `Tensor` of shape `[B, T, dim]`.
value: Value `Tensor` of shape `[B, S, dim]`.
key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
`value` for both `key` and `value`, which is the most common case.
attention_mask: a boolean mask of shape `[B, T, S]`, that prevents attention
to certain positions.
""" """
def __init__(self, def __init__(self,
...@@ -214,6 +221,7 @@ class MultiHeadAttention(tf.keras.layers.Layer): ...@@ -214,6 +221,7 @@ class MultiHeadAttention(tf.keras.layers.Layer):
self._attention_axes = (attention_axes,) self._attention_axes = (attention_axes,)
else: else:
self._attention_axes = attention_axes self._attention_axes = attention_axes
self._built_from_signature = False
def get_config(self): def get_config(self):
config = { config = {
...@@ -251,17 +259,31 @@ class MultiHeadAttention(tf.keras.layers.Layer): ...@@ -251,17 +259,31 @@ class MultiHeadAttention(tf.keras.layers.Layer):
base_config = super(MultiHeadAttention, self).get_config() base_config = super(MultiHeadAttention, self).get_config()
return dict(list(base_config.items()) + list(config.items())) return dict(list(base_config.items()) + list(config.items()))
def build(self, input_shape): def _build_from_signature(self, query, value, key=None):
inputs_len = len(input_shape) """Builds layers and variables.
if inputs_len > 3 or inputs_len < 2:
raise ValueError( Once the method is called, self._built_from_signature will be set to True.
"Expects inputs list of length 2 or 3, namely [query, value] or "
"[query, value, key]. " Args:
"Given length: %d" % inputs_len) query: query tensor or TensorShape.
tensor_shapes = tf.nest.map_structure(tf.TensorShape, input_shape) value: value tensor or TensorShape.
query_shape = tensor_shapes[0] key: key tensor or TensorShape.
value_shape = tensor_shapes[1] """
key_shape = tensor_shapes[2] if inputs_len == 3 else value_shape self._built_from_signature = True
if hasattr(query, "shape"):
query_shape = tf.TensorShape(query.shape)
else:
query_shape = query
if hasattr(value, "shape"):
value_shape = tf.TensorShape(value.shape)
else:
value_shape = value
if key is None:
key_shape = value_shape
elif hasattr(key, "shape"):
key_shape = tf.TensorShape(key.shape)
else:
key_shape = key
common_kwargs = dict( common_kwargs = dict(
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
...@@ -271,84 +293,79 @@ class MultiHeadAttention(tf.keras.layers.Layer): ...@@ -271,84 +293,79 @@ class MultiHeadAttention(tf.keras.layers.Layer):
activity_regularizer=self._activity_regularizer, activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint, kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint) bias_constraint=self._bias_constraint)
with tf.init_scope():
free_dims = query_shape.rank - 1 free_dims = query_shape.rank - 1
einsum_equation, bias_axes, output_rank = _build_proj_equation( einsum_equation, bias_axes, output_rank = _build_proj_equation(
free_dims, bound_dims=1, output_dims=2) free_dims, bound_dims=1, output_dims=2)
self._query_dense = EinsumDense( self._query_dense = EinsumDense(
einsum_equation, einsum_equation,
output_shape=_get_output_shape(output_rank - 1, output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._key_size]), [self._num_heads, self._key_size]),
bias_axes=bias_axes if self._use_bias else None, bias_axes=bias_axes if self._use_bias else None,
name="query", name="query",
**common_kwargs) **common_kwargs)
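# Illustrative note (added for clarity, not part of the original change): for a
# rank-3 query of shape [B, T, dim], free_dims is 2 and
# _build_proj_equation(2, bound_dims=1, output_dims=2) returns the equation
# "abc,cde->abde" with bias axes "de", so the EinsumDense above projects
# [B, T, dim] -> [B, T, num_heads, key_size].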
einsum_equation, bias_axes, output_rank = _build_proj_equation( einsum_equation, bias_axes, output_rank = _build_proj_equation(
key_shape.rank - 1, bound_dims=1, output_dims=2) key_shape.rank - 1, bound_dims=1, output_dims=2)
self._key_dense = EinsumDense( self._key_dense = EinsumDense(
einsum_equation, einsum_equation,
output_shape=_get_output_shape(output_rank - 1, output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._key_size]), [self._num_heads, self._key_size]),
bias_axes=bias_axes if self._use_bias else None, bias_axes=bias_axes if self._use_bias else None,
name="key", name="key",
**common_kwargs) **common_kwargs)
einsum_equation, bias_axes, output_rank = _build_proj_equation( einsum_equation, bias_axes, output_rank = _build_proj_equation(
value_shape.rank - 1, bound_dims=1, output_dims=2) value_shape.rank - 1, bound_dims=1, output_dims=2)
self._value_dense = EinsumDense( self._value_dense = EinsumDense(
einsum_equation, einsum_equation,
output_shape=_get_output_shape(output_rank - 1, output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._value_size]), [self._num_heads, self._value_size]),
bias_axes=bias_axes if self._use_bias else None, bias_axes=bias_axes if self._use_bias else None,
name="value", name="value",
**common_kwargs) **common_kwargs)
# Builds the attention computations for multi-head dot product attention. # Builds the attention computations for multi-head dot product attention.
# These computations could be wrapped into the keras attention layer once it # These computations could be wrapped into the keras attention layer once
# supports multi-head einsum computations. # it supports multi-head einsum computations.
self._build_attention(output_rank) self.build_attention(output_rank)
if self._output_shape: if self._output_shape:
if not isinstance(self._output_shape, collections.abc.Sized): if not isinstance(self._output_shape, collections.abc.Sized):
output_shape = [self._output_shape] output_shape = [self._output_shape]
else:
output_shape = self._output_shape
else: else:
output_shape = self._output_shape output_shape = [query_shape[-1]]
else: einsum_equation, bias_axes, output_rank = _build_proj_equation(
output_shape = [query_shape[-1]] free_dims, bound_dims=2, output_dims=len(output_shape))
einsum_equation, bias_axes, output_rank = _build_proj_equation( self._output_dense = EinsumDense(
free_dims, bound_dims=2, output_dims=len(output_shape)) einsum_equation,
self._output_dense = EinsumDense( output_shape=_get_output_shape(output_rank - 1, output_shape),
einsum_equation, bias_axes=bias_axes if self._use_bias else None,
output_shape=_get_output_shape(output_rank - 1, output_shape), name="attention_output",
bias_axes=bias_axes if self._use_bias else None, **common_kwargs)
name="attention_output",
**common_kwargs) def build_attention(self, rank):
super(MultiHeadAttention, self).build(input_shape)
def _build_attention(self, qkv_rank):
"""Builds multi-head dot-product attention computations. """Builds multi-head dot-product attention computations.
This function builds attributes necessary for `_compute_attention` to This function builds attributes necessary for `compute_attention` to
customize attention computation to replace the default dot-product customize attention computation to replace the default dot-product
attention. attention.
Args: Args:
qkv_rank: the rank of query, key, value tensors. rank: the rank of query, key, value tensors.
""" """
if self._attention_axes is None: if self._attention_axes is None:
self._attention_axes = tuple(range(1, qkv_rank - 2)) self._attention_axes = tuple(range(1, rank - 2))
else: else:
self._attention_axes = tuple(self._attention_axes) self._attention_axes = tuple(self._attention_axes)
self._dot_product_equation, self._combine_equation, attn_scores_rank = ( self._dot_product_equation, self._combine_equation, attn_scores_rank = (
_build_attention_equation(qkv_rank, attn_axes=self._attention_axes)) _build_attention_equation(rank, attn_axes=self._attention_axes))
norm_axes = tuple( norm_axes = tuple(
range(attn_scores_rank - len(self._attention_axes), attn_scores_rank)) range(attn_scores_rank - len(self._attention_axes), attn_scores_rank))
self._masked_softmax = masked_softmax.MaskedSoftmax( self._masked_softmax = masked_softmax.MaskedSoftmax(
mask_expansion_axes=[1], normalization_axes=norm_axes) mask_expansion_axes=[1], normalization_axes=norm_axes)
self._dropout_layer = tf.keras.layers.Dropout(rate=self._dropout) self._dropout_layer = tf.keras.layers.Dropout(rate=self._dropout)
def _compute_attention(self, def compute_attention(self, query, key, value, attention_mask=None):
query_tensor,
key_tensor,
value_tensor,
attention_mask=None):
"""Applies Dot-product attention with query, key, value tensors. """Applies Dot-product attention with query, key, value tensors.
This function defines the computation inside `call` with projected This function defines the computation inside `call` with projected
...@@ -356,9 +373,9 @@ class MultiHeadAttention(tf.keras.layers.Layer): ...@@ -356,9 +373,9 @@ class MultiHeadAttention(tf.keras.layers.Layer):
attention implementation. attention implementation.
Args: Args:
query_tensor: Projected query `Tensor` of shape `[B, T, N, key_size]`. query: Projected query `Tensor` of shape `[B, T, N, key_size]`.
key_tensor: Projected key `Tensor` of shape `[B, T, N, key_size]`. key: Projected key `Tensor` of shape `[B, T, N, key_size]`.
value_tensor: Projected value `Tensor` of shape `[B, T, N, value_size]`. value: Projected value `Tensor` of shape `[B, T, N, value_size]`.
attention_mask: a boolean mask of shape `[B, T, S]`, that prevents attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
attention to certain positions. attention to certain positions.
...@@ -366,12 +383,14 @@ class MultiHeadAttention(tf.keras.layers.Layer): ...@@ -366,12 +383,14 @@ class MultiHeadAttention(tf.keras.layers.Layer):
attention_output: Multi-headed outputs of attention computation. attention_output: Multi-headed outputs of attention computation.
attention_scores: Multi-headed attention weights. attention_scores: Multi-headed attention weights.
""" """
# Note: Applying scalar multiply at the smaller end of einsum improves
# XLA performance, but may introduce slight numeric differences in
# the Transformer attention head.
query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_size)))
# Take the dot product between "query" and "key" to get the raw # Take the dot product between "query" and "key" to get the raw
# attention scores. # attention scores.
attention_scores = tf.einsum(self._dot_product_equation, key_tensor, attention_scores = tf.einsum(self._dot_product_equation, key, query)
query_tensor)
attention_scores = tf.multiply(attention_scores,
1.0 / math.sqrt(float(self._key_size)))
# Normalize the attention scores to probabilities. # Normalize the attention scores to probabilities.
# `attention_scores` = [B, N, T, S] # `attention_scores` = [B, N, T, S]
...@@ -383,10 +402,10 @@ class MultiHeadAttention(tf.keras.layers.Layer): ...@@ -383,10 +402,10 @@ class MultiHeadAttention(tf.keras.layers.Layer):
# `context_layer` = [B, T, N, H] # `context_layer` = [B, T, N, H]
attention_output = tf.einsum(self._combine_equation, attention_output = tf.einsum(self._combine_equation,
attention_scores_dropout, value_tensor) attention_scores_dropout, value)
return attention_output, attention_scores return attention_output, attention_scores
def call(self, inputs, attention_mask=None): def call(self, query, value, key=None, attention_mask=None):
"""Implements the forward pass. """Implements the forward pass.
Size glossary: Size glossary:
...@@ -399,11 +418,10 @@ class MultiHeadAttention(tf.keras.layers.Layer): ...@@ -399,11 +418,10 @@ class MultiHeadAttention(tf.keras.layers.Layer):
* Value (source) attention axes shape (S), the rank must match the target. * Value (source) attention axes shape (S), the rank must match the target.
Args: Args:
inputs: List of the following tensors: query: Query `Tensor` of shape `[B, T, dim]`.
* query: Query `Tensor` of shape `[B, T, dim]`. value: Value `Tensor` of shape `[B, S, dim]`.
* value: Value `Tensor` of shape `[B, S, dim]`. key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
* key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will `value` for both `key` and `value`, which is the most common case.
use `value` for both `key` and `value`, which is the most common case.
attention_mask: a boolean mask of shape `[B, T, S]`, that prevents attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
attention to certain positions. attention to certain positions.
...@@ -416,29 +434,24 @@ class MultiHeadAttention(tf.keras.layers.Layer): ...@@ -416,29 +434,24 @@ class MultiHeadAttention(tf.keras.layers.Layer):
attention attention
axes. axes.
""" """
inputs_len = len(inputs) if not self._built_from_signature:
if inputs_len > 3 or inputs_len < 2: self._build_from_signature(query=query, value=value, key=key)
raise ValueError( if key is None:
"Expects inputs list of length 2 or 3, namely [query, value] or " key = value
"[query, value, key]. "
"Given length: %d" % inputs_len)
query = inputs[0]
value = inputs[1]
key = inputs[2] if inputs_len == 3 else value
# N = `num_attention_heads` # N = `num_attention_heads`
# H = `size_per_head` # H = `size_per_head`
# `query_tensor` = [B, T, N ,H] # `query` = [B, T, N ,H]
query_tensor = self._query_dense(query) query = self._query_dense(query)
# `key_tensor` = [B, S, N, H] # `key` = [B, S, N, H]
key_tensor = self._key_dense(key) key = self._key_dense(key)
# `value_tensor` = [B, S, N, H] # `value` = [B, S, N, H]
value_tensor = self._value_dense(value) value = self._value_dense(value)
attention_output, attention_scores = self._compute_attention( attention_output, attention_scores = self.compute_attention(
query_tensor, key_tensor, value_tensor, attention_mask) query, key, value, attention_mask)
attention_output = self._output_dense(attention_output) attention_output = self._output_dense(attention_output)
if self._return_attention_scores: if self._return_attention_scores:
...@@ -453,40 +466,42 @@ class CachedAttention(MultiHeadAttention): ...@@ -453,40 +466,42 @@ class CachedAttention(MultiHeadAttention):
Arguments are the same as `MultiHeadAttention` layer. Arguments are the same as `MultiHeadAttention` layer.
""" """
def _update_cache(self, key_tensor, value_tensor, cache, decode_loop_step): def _update_cache(self, key, value, cache, decode_loop_step):
"""Updates cache states and gets full-length key/value tensors.""" """Updates cache states and gets full-length key/value tensors."""
# Combines cached keys and values with new keys and values. # Combines cached keys and values with new keys and values.
if decode_loop_step is not None: if decode_loop_step is not None:
# TPU special case. # TPU special case.
key_seq_dim = cache["key"].shape.as_list()[1] key_seq_dim = cache["key"].shape.as_list()[1]
indices = tf.reshape( indices = tf.reshape(
tf.one_hot(decode_loop_step, key_seq_dim, dtype=key_tensor.dtype), tf.one_hot(decode_loop_step, key_seq_dim, dtype=key.dtype),
[1, key_seq_dim, 1, 1]) [1, key_seq_dim, 1, 1])
key_tensor = cache["key"] + key_tensor * indices key = cache["key"] + key * indices
value_seq_dim = cache["value"].shape.as_list()[1] value_seq_dim = cache["value"].shape.as_list()[1]
indices = tf.reshape( indices = tf.reshape(
tf.one_hot(decode_loop_step, value_seq_dim, dtype=value_tensor.dtype), tf.one_hot(decode_loop_step, value_seq_dim, dtype=value.dtype),
[1, value_seq_dim, 1, 1]) [1, value_seq_dim, 1, 1])
value_tensor = cache["value"] + value_tensor * indices value = cache["value"] + value * indices
else: else:
key_tensor = tf.concat( key = tf.concat([tf.cast(cache["key"], key.dtype), key], axis=1)
[tf.cast(cache["key"], key_tensor.dtype), key_tensor], axis=1) value = tf.concat([tf.cast(cache["value"], value.dtype), value], axis=1)
value_tensor = tf.concat(
[tf.cast(cache["value"], value_tensor.dtype), value_tensor], axis=1)
# Update cache # Update cache
cache["key"] = key_tensor cache["key"] = key
cache["value"] = value_tensor cache["value"] = value
return key_tensor, value_tensor return key, value
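# Illustrative decoding-loop usage (a sketch based on the unit tests below;
# the zero-initialized cache layout is an assumption of this example):
#   cache = {"key": tf.zeros([batch, max_len, num_heads, head_size]),
#            "value": tf.zeros([batch, max_len, num_heads, head_size])}
#   for step in range(max_len):
#     output, cache = layer(query=q, value=q, cache=cache,
#                           decode_loop_step=step)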
def call(self, def call(self,
inputs, query,
value,
key=None,
attention_mask=None, attention_mask=None,
cache=None, cache=None,
decode_loop_step=None): decode_loop_step=None):
from_tensor = inputs[0] if not self._built_from_signature:
to_tensor = inputs[1] self._build_from_signature(query=query, value=value, key=key)
if key is None:
key = value
# Scalar dimensions referenced here: # Scalar dimensions referenced here:
# B = batch size (number of sequences) # B = batch size (number of sequences)
...@@ -494,23 +509,21 @@ class CachedAttention(MultiHeadAttention): ...@@ -494,23 +509,21 @@ class CachedAttention(MultiHeadAttention):
# T = `to_tensor` sequence length # T = `to_tensor` sequence length
# N = `num_attention_heads` # N = `num_attention_heads`
# H = `size_per_head` # H = `size_per_head`
# `query_tensor` = [B, F, N ,H] # `query` = [B, F, N ,H]
query_tensor = self._query_dense(from_tensor) query = self._query_dense(query)
# `key_tensor` = [B, T, N, H] # `key` = [B, T, N, H]
key_tensor = self._key_dense(to_tensor) key = self._key_dense(key)
# `value_tensor` = [B, T, N, H] # `value` = [B, T, N, H]
value_tensor = self._value_dense(to_tensor) value = self._value_dense(value)
if cache: if cache:
key_tensor, value_tensor = self._update_cache(key_tensor, value_tensor, key, value = self._update_cache(key, value, cache, decode_loop_step)
cache, decode_loop_step)
# Take the dot product between "query" and "key" to get the raw # Take the dot product between "query" and "key" to get the raw
# attention scores. # attention scores.
attention_scores = tf.einsum(self._dot_product_equation, key_tensor, attention_scores = tf.einsum(self._dot_product_equation, key, query)
query_tensor)
attention_scores = tf.multiply(attention_scores, attention_scores = tf.multiply(attention_scores,
1.0 / math.sqrt(float(self._key_size))) 1.0 / math.sqrt(float(self._key_size)))
...@@ -523,7 +536,7 @@ class CachedAttention(MultiHeadAttention): ...@@ -523,7 +536,7 @@ class CachedAttention(MultiHeadAttention):
attention_scores = self._dropout_layer(attention_scores) attention_scores = self._dropout_layer(attention_scores)
# `context_layer` = [B, F, N, H] # `context_layer` = [B, F, N, H]
attention_output = tf.einsum(self._combine_equation, attention_scores, attention_output = tf.einsum(self._combine_equation, attention_scores,
value_tensor) value)
attention_output = self._output_dense(attention_output) attention_output = self._output_dense(attention_output)
if self._return_attention_scores: if self._return_attention_scores:
return attention_output, attention_scores, cache return attention_output, attention_scores, cache
......
...@@ -45,7 +45,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase): ...@@ -45,7 +45,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80)) query = tf.keras.Input(shape=(40, 80))
value = tf.keras.Input(shape=(20, 80)) value = tf.keras.Input(shape=(20, 80))
output = test_layer([query, value]) output = test_layer(query=query, value=value)
self.assertEqual(output.shape.as_list(), [None] + output_dims) self.assertEqual(output.shape.as_list(), [None] + output_dims)
def test_non_masked_self_attention(self): def test_non_masked_self_attention(self):
...@@ -53,7 +53,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase): ...@@ -53,7 +53,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
test_layer = attention.MultiHeadAttention(num_heads=12, key_size=64) test_layer = attention.MultiHeadAttention(num_heads=12, key_size=64)
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80)) query = tf.keras.Input(shape=(40, 80))
output = test_layer([query, query]) output = test_layer(query, query)
self.assertEqual(output.shape.as_list(), [None, 40, 80]) self.assertEqual(output.shape.as_list(), [None, 40, 80])
def test_attention_scores(self): def test_attention_scores(self):
...@@ -62,7 +62,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase): ...@@ -62,7 +62,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
num_heads=12, key_size=64, return_attention_scores=True) num_heads=12, key_size=64, return_attention_scores=True)
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80)) query = tf.keras.Input(shape=(40, 80))
output, coef = test_layer([query, query]) output, coef = test_layer(query, query)
self.assertEqual(output.shape.as_list(), [None, 40, 80]) self.assertEqual(output.shape.as_list(), [None, 40, 80])
self.assertEqual(coef.shape.as_list(), [None, 12, 40, 40]) self.assertEqual(coef.shape.as_list(), [None, 12, 40, 40])
...@@ -76,7 +76,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase): ...@@ -76,7 +76,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
query = tf.keras.Input(shape=(4, 8)) query = tf.keras.Input(shape=(4, 8))
value = tf.keras.Input(shape=(2, 8)) value = tf.keras.Input(shape=(2, 8))
mask_tensor = tf.keras.Input(shape=(4, 2)) mask_tensor = tf.keras.Input(shape=(4, 2))
output = test_layer([query, value], mask_tensor) output = test_layer(query=query, value=value, attention_mask=mask_tensor)
# Create a model containing the test layer. # Create a model containing the test layer.
model = tf.keras.Model([query, value, mask_tensor], output) model = tf.keras.Model([query, value, mask_tensor], output)
...@@ -100,7 +100,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase): ...@@ -100,7 +100,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
# Tests the layer with three inputs: Q, K, V. # Tests the layer with three inputs: Q, K, V.
key = tf.keras.Input(shape=(2, 8)) key = tf.keras.Input(shape=(2, 8))
output = test_layer([query, value, key], mask_tensor) output = test_layer(query, value=value, key=key, attention_mask=mask_tensor)
model = tf.keras.Model([query, value, key, mask_tensor], output) model = tf.keras.Model([query, value, key, mask_tensor], output)
masked_output_data = model.predict([from_data, to_data, to_data, mask_data]) masked_output_data = model.predict([from_data, to_data, to_data, mask_data])
...@@ -125,7 +125,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase): ...@@ -125,7 +125,7 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02)) kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80)) query = tf.keras.Input(shape=(40, 80))
output = test_layer([query, query]) output = test_layer(query, query)
self.assertEqual(output.shape.as_list(), [None, 40, 80]) self.assertEqual(output.shape.as_list(), [None, 40, 80])
@parameterized.named_parameters( @parameterized.named_parameters(
...@@ -147,11 +147,12 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase): ...@@ -147,11 +147,12 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
# Invoke the data with a random set of mask data. This should mask at least # Invoke the data with a random set of mask data. This should mask at least
# one element. # one element.
mask_data = np.random.randint(2, size=mask_shape).astype("bool") mask_data = np.random.randint(2, size=mask_shape).astype("bool")
output = test_layer([query, value], mask_data) output = test_layer(query=query, value=value, attention_mask=mask_data)
# Invoke the same data, but with a null mask (where no elements are masked). # Invoke the same data, but with a null mask (where no elements are masked).
null_mask_data = np.ones(mask_shape) null_mask_data = np.ones(mask_shape)
unmasked_output = test_layer([query, value], null_mask_data) unmasked_output = test_layer(
query=query, value=value, attention_mask=null_mask_data)
# Because one data is masked and one is not, the outputs should not be the # Because one data is masked and one is not, the outputs should not be the
# same. # same.
self.assertNotAllClose(output, unmasked_output) self.assertNotAllClose(output, unmasked_output)
...@@ -180,7 +181,7 @@ class AttentionSubclassTest(keras_parameterized.TestCase): ...@@ -180,7 +181,7 @@ class AttentionSubclassTest(keras_parameterized.TestCase):
key_size=64) key_size=64)
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80)) query = tf.keras.Input(shape=(40, 80))
output = test_layer([query, query]) output = test_layer(query, query)
self.assertEqual(output.shape.as_list(), [None, 40, 80]) self.assertEqual(output.shape.as_list(), [None, 40, 80])
...@@ -216,12 +217,14 @@ class CachedAttentionTest(keras_parameterized.TestCase): ...@@ -216,12 +217,14 @@ class CachedAttentionTest(keras_parameterized.TestCase):
# one element. # one element.
mask_data = np.random.randint( mask_data = np.random.randint(
2, size=(batch_size, from_seq_length, from_seq_length)) 2, size=(batch_size, from_seq_length, from_seq_length))
masked_output_data, cache = layer([from_data, from_data], mask_data, cache) masked_output_data, cache = layer(
query=from_data, value=from_data, attention_mask=mask_data, cache=cache)
self.assertEqual(masked_output_data.shape, (3, 4, 8)) self.assertEqual(masked_output_data.shape, (3, 4, 8))
self.assertEqual(cache["value"].shape, (3, 4, 2, 2)) self.assertEqual(cache["value"].shape, (3, 4, 2, 2))
# Tests inputs without cache. # Tests inputs without cache.
masked_output_data, cache = layer([from_data, from_data, mask_data]) masked_output_data, cache = layer(
query=from_data, value=from_data, attention_mask=mask_data)
self.assertEqual(masked_output_data.shape, (3, 4, 8)) self.assertEqual(masked_output_data.shape, (3, 4, 8))
self.assertIsNone(cache) self.assertIsNone(cache)
...@@ -243,10 +246,12 @@ class CachedAttentionTest(keras_parameterized.TestCase): ...@@ -243,10 +246,12 @@ class CachedAttentionTest(keras_parameterized.TestCase):
mask_data = np.random.randint( mask_data = np.random.randint(
2, size=(batch_size, from_seq_length, from_seq_length), dtype=np.int32) 2, size=(batch_size, from_seq_length, from_seq_length), dtype=np.int32)
# Testing the invocation directly as Keras cannot consume inputs correctly. # Testing the invocation directly as Keras cannot consume inputs correctly.
masked_output_data, cache = layer([from_data, from_data], masked_output_data, cache = layer(
mask_data, query=from_data,
cache, value=from_data,
decode_loop_step=decode_loop_step) attention_mask=mask_data,
cache=cache,
decode_loop_step=decode_loop_step)
self.assertEqual(masked_output_data.shape, (3, 4, 8)) self.assertEqual(masked_output_data.shape, (3, 4, 8))
self.assertEqual(cache["value"].shape, (3, 4, 2, 2)) self.assertEqual(cache["value"].shape, (3, 4, 2, 2))
......
...@@ -21,6 +21,8 @@ from __future__ import print_function ...@@ -21,6 +21,8 @@ from __future__ import print_function
import tensorflow as tf import tensorflow as tf
from tensorflow.python.util import deprecation
_CHR_IDX = ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m"] _CHR_IDX = ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m"]
...@@ -57,6 +59,9 @@ class DenseEinsum(tf.keras.layers.Layer): ...@@ -57,6 +59,9 @@ class DenseEinsum(tf.keras.layers.Layer):
`(batch_size, units)`. `(batch_size, units)`.
""" """
@deprecation.deprecated(
None, "DenseEinsum is deprecated. Please use "
"tf.keras.experimental.EinsumDense layer instead.")
def __init__(self, def __init__(self,
output_shape, output_shape,
num_summed_dimensions=1, num_summed_dimensions=1,
......
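# Migration note (illustrative, not part of the original change): a deprecated
# DenseEinsum feedforward such as
#   dense_einsum.DenseEinsum(output_shape=hidden_size, name="output")
# applied to a [batch, seq, width] input corresponds to the replacement used
# throughout this change:
#   tf.keras.layers.experimental.EinsumDense(
#       "abc,cd->abd",
#       output_shape=(None, hidden_size),
#       bias_axes="d",
#       name="output")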
...@@ -26,7 +26,6 @@ import math ...@@ -26,7 +26,6 @@ import math
import tensorflow as tf import tensorflow as tf
from official.modeling import tf_utils from official.modeling import tf_utils
from official.nlp.modeling.layers import attention from official.nlp.modeling.layers import attention
from official.nlp.modeling.layers import dense_einsum
from official.nlp.modeling.layers import masked_softmax from official.nlp.modeling.layers import masked_softmax
...@@ -67,28 +66,26 @@ class VotingAttention(tf.keras.layers.Layer): ...@@ -67,28 +66,26 @@ class VotingAttention(tf.keras.layers.Layer):
self._bias_constraint = tf.keras.constraints.get(bias_constraint) self._bias_constraint = tf.keras.constraints.get(bias_constraint)
def build(self, unused_input_shapes): def build(self, unused_input_shapes):
self._query_dense = dense_einsum.DenseEinsum( common_kwargs = dict(
output_shape=(self._num_heads, self._head_size),
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer, bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer, kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer, bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer, activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint, kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint, bias_constraint=self._bias_constraint)
dtype=self.dtype, self._query_dense = tf.keras.layers.experimental.EinsumDense(
name="encdocatt_query") "BAE,ENH->BANH",
self._key_dense = dense_einsum.DenseEinsum( output_shape=(None, self._num_heads, self._head_size),
output_shape=(self._num_heads, self._head_size), bias_axes="NH",
kernel_initializer=self._kernel_initializer, name="query",
bias_initializer=self._bias_initializer, **common_kwargs)
kernel_regularizer=self._kernel_regularizer, self._key_dense = tf.keras.layers.experimental.EinsumDense(
bias_regularizer=self._bias_regularizer, "BAE,ENH->BANH",
activity_regularizer=self._activity_regularizer, output_shape=(None, self._num_heads, self._head_size),
kernel_constraint=self._kernel_constraint, bias_axes="NH",
bias_constraint=self._bias_constraint, name="key",
dtype=self.dtype, **common_kwargs)
name="encdocatt_key")
super(VotingAttention, self).build(unused_input_shapes) super(VotingAttention, self).build(unused_input_shapes)
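# Illustrative reading of the einsum "BAE,ENH->BANH" above (annotation added
# for clarity, not part of the original change): B = batch, A = num_docs,
# E = input embedding width, N = num_heads, H = head_size, so each document
# representation is projected into per-head query/key spaces.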
def call(self, encoder_outputs, doc_attention_mask): def call(self, encoder_outputs, doc_attention_mask):
...@@ -113,34 +110,52 @@ class VotingAttention(tf.keras.layers.Layer): ...@@ -113,34 +110,52 @@ class VotingAttention(tf.keras.layers.Layer):
class MultiChannelAttention(attention.MultiHeadAttention): class MultiChannelAttention(attention.MultiHeadAttention):
"""Multi-channel Attention layer. """Multi-channel Attention layer.
Introduced in [Generating Representative Headlines for News Stories](https://arxiv.org/abs/2001.09386). Expects multiple cross-attention target sequences.

Call args:
query: Query `Tensor` of shape `[B, T, dim]`.
value: Value `Tensor` of shape `[B, A, S, dim]`, where A denotes the number of
documents (channels).
context_attention_weights: Context weights of shape `[B, N, T, A]`, where N
is the number of attention heads. Combines multi-channel source
context tensors according to the distribution among channels.
key: Optional key `Tensor` of shape `[B, A, S, dim]`. If not given, will use
`value` for both `key` and `value`, which is the most common case.
attention_mask: a boolean mask of shape `[B, T, S]`, that prevents attention
to certain positions.
""" """
def _build_attention(self, qkv_rank): def build_attention(self, rank):
super(MultiChannelAttention, self)._build_attention(qkv_rank) super(MultiChannelAttention, self).build_attention(rank)
self._masked_softmax = masked_softmax.MaskedSoftmax(mask_expansion_axes=[2]) self._masked_softmax = masked_softmax.MaskedSoftmax(mask_expansion_axes=[2])
def call(self, inputs, attention_mask=None): def call(self,
from_tensor = inputs[0] query,
to_tensor = inputs[1] value,
doc_attention_probs = inputs[2] key=None,
context_attention_weights=None,
attention_mask=None):
if not self._built_from_signature:
self._build_from_signature(query, value, key=key)
if key is None:
key = value
# Scalar dimensions referenced here: # Scalar dimensions referenced here:
# B = batch size (number of stories) # B = batch size (number of stories)
# A = num_docs (number of docs) # A = num_docs (number of docs)
# F = `from_tensor` sequence length # F = target sequence length
# T = `to_tensor` sequence length # T = source sequence length
# N = `num_attention_heads` # N = `num_attention_heads`
# H = `size_per_head` # H = `size_per_head`
# `query_tensor` = [B, F, N ,H] # `query_tensor` = [B, F, N ,H]
query_tensor = self._query_dense(from_tensor) query_tensor = self._query_dense(query)
# `key_tensor` = [B, A, T, N, H] # `key_tensor` = [B, A, T, N, H]
key_tensor = self._key_dense(to_tensor) key_tensor = self._key_dense(key)
# `value_tensor` = [B, A, T, N, H] # `value_tensor` = [B, A, T, N, H]
value_tensor = self._value_dense(to_tensor) value_tensor = self._value_dense(value)
# Take the dot product between "query" and "key" to get the raw # Take the dot product between "query" and "key" to get the raw
# attention scores. # attention scores.
...@@ -159,7 +174,7 @@ class MultiChannelAttention(attention.MultiHeadAttention): ...@@ -159,7 +174,7 @@ class MultiChannelAttention(attention.MultiHeadAttention):
# `context_layer` = [B, F, N, H] # `context_layer` = [B, F, N, H]
context_layer = tf.einsum("BANFT,BATNH->BAFNH", attention_probs, context_layer = tf.einsum("BANFT,BATNH->BAFNH", attention_probs,
value_tensor) value_tensor)
attention_output = tf.einsum("BNFA,BAFNH->BFNH", doc_attention_probs, attention_output = tf.einsum("BNFA,BAFNH->BFNH", context_attention_weights,
context_layer) context_layer)
attention_output = self._output_dense(attention_output) attention_output = self._output_dense(attention_output)
return attention_output return attention_output
...@@ -48,7 +48,11 @@ class MultiChannelAttentionTest(tf.test.TestCase): ...@@ -48,7 +48,11 @@ class MultiChannelAttentionTest(tf.test.TestCase):
mask_data = np.random.randint(2, size=(3, num_docs, 4, 2)) mask_data = np.random.randint(2, size=(3, num_docs, 4, 2))
doc_probs = np.random.randint( doc_probs = np.random.randint(
2, size=(3, num_heads, 4, num_docs)).astype(float) 2, size=(3, num_heads, 4, num_docs)).astype(float)
outputs = attention_layer([from_data, to_data, doc_probs], mask_data) outputs = attention_layer(
query=from_data,
value=to_data,
context_attention_weights=doc_probs,
attention_mask=mask_data)
self.assertEqual(outputs.shape, (3, 4, 8)) self.assertEqual(outputs.shape, (3, 4, 8))
......
...@@ -160,7 +160,6 @@ class RelativePositionEmbedding(tf.keras.layers.Layer): ...@@ -160,7 +160,6 @@ class RelativePositionEmbedding(tf.keras.layers.Layer):
"hidden_size": self._hidden_size, "hidden_size": self._hidden_size,
"min_timescale": self._min_timescale, "min_timescale": self._min_timescale,
"max_timescale": self._max_timescale, "max_timescale": self._max_timescale,
"length": self._length,
} }
base_config = super(RelativePositionEmbedding, self).get_config() base_config = super(RelativePositionEmbedding, self).get_config()
return dict(list(base_config.items()) + list(config.items())) return dict(list(base_config.items()) + list(config.items()))
......
...@@ -23,7 +23,6 @@ import gin ...@@ -23,7 +23,6 @@ import gin
import tensorflow as tf import tensorflow as tf
from official.nlp.modeling.layers import attention from official.nlp.modeling.layers import attention
from official.nlp.modeling.layers import dense_einsum
@tf.keras.utils.register_keras_serializable(package="Text") @tf.keras.utils.register_keras_serializable(package="Text")
...@@ -109,19 +108,20 @@ class ReZeroTransformer(tf.keras.layers.Layer): ...@@ -109,19 +108,20 @@ class ReZeroTransformer(tf.keras.layers.Layer):
"The input size (%d) is not a multiple of the number of attention " "The input size (%d) is not a multiple of the number of attention "
"heads (%d)" % (hidden_size, self._num_heads)) "heads (%d)" % (hidden_size, self._num_heads))
self._attention_head_size = int(hidden_size // self._num_heads) self._attention_head_size = int(hidden_size // self._num_heads)
common_kwargs = dict(
self._attention_layer = attention.MultiHeadAttention(
num_heads=self._num_heads,
key_size=self._attention_head_size,
dropout=self._attention_dropout_rate,
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer, bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer, kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer, bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer, activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint, kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint, bias_constraint=self._bias_constraint)
name="self_attention") self._attention_layer = attention.MultiHeadAttention(
num_heads=self._num_heads,
key_size=self._attention_head_size,
dropout=self._attention_dropout_rate,
name="self_attention",
**common_kwargs)
self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate) self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
if self._use_layer_norm: if self._use_layer_norm:
# Use float32 in layernorm for numeric stability. # Use float32 in layernorm for numeric stability.
...@@ -132,17 +132,12 @@ class ReZeroTransformer(tf.keras.layers.Layer): ...@@ -132,17 +132,12 @@ class ReZeroTransformer(tf.keras.layers.Layer):
axis=-1, axis=-1,
epsilon=1e-12, epsilon=1e-12,
dtype=tf.float32)) dtype=tf.float32))
self._intermediate_dense = dense_einsum.DenseEinsum( self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
output_shape=self._intermediate_size, "abc,cd->abd",
activation=None, output_shape=(None, self._intermediate_size),
kernel_initializer=self._kernel_initializer, bias_axes="d",
bias_initializer=self._bias_initializer, name="intermediate",
kernel_regularizer=self._kernel_regularizer, **common_kwargs)
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="intermediate")
policy = tf.keras.mixed_precision.experimental.global_policy() policy = tf.keras.mixed_precision.experimental.global_policy()
if policy.name == "mixed_bfloat16": if policy.name == "mixed_bfloat16":
# bfloat16 causes BERT with the LAMB optimizer to not converge # bfloat16 causes BERT with the LAMB optimizer to not converge
...@@ -151,16 +146,12 @@ class ReZeroTransformer(tf.keras.layers.Layer): ...@@ -151,16 +146,12 @@ class ReZeroTransformer(tf.keras.layers.Layer):
policy = tf.float32 policy = tf.float32
self._intermediate_activation_layer = tf.keras.layers.Activation( self._intermediate_activation_layer = tf.keras.layers.Activation(
self._intermediate_activation, dtype=policy) self._intermediate_activation, dtype=policy)
self._output_dense = dense_einsum.DenseEinsum( self._output_dense = tf.keras.layers.experimental.EinsumDense(
output_shape=hidden_size, "abc,cd->abd",
kernel_initializer=self._kernel_initializer, output_shape=(None, hidden_size),
bias_initializer=self._bias_initializer, bias_axes="d",
kernel_regularizer=self._kernel_regularizer, name="output",
bias_regularizer=self._bias_regularizer, **common_kwargs)
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="output")
self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate) self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
if self._use_layer_norm: if self._use_layer_norm:
# Use float32 in layernorm for numeric stability. # Use float32 in layernorm for numeric stability.
...@@ -222,9 +213,9 @@ class ReZeroTransformer(tf.keras.layers.Layer): ...@@ -222,9 +213,9 @@ class ReZeroTransformer(tf.keras.layers.Layer):
attention_mask = attention_mask[:, 0:self._output_range, :] attention_mask = attention_mask[:, 0:self._output_range, :]
else: else:
target_tensor = input_tensor target_tensor = input_tensor
attention_inputs = [target_tensor, input_tensor]
attention_output = self._attention_layer(attention_inputs, attention_mask) attention_output = self._attention_layer(
query=target_tensor, value=input_tensor, attention_mask=attention_mask)
attention_output = self._attention_dropout(attention_output) attention_output = self._attention_dropout(attention_output)
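# The following addition is the ReZero residual: `self._rezero_a` is a
# trainable scalar gate which the ReZero technique initializes to zero, so the
# block starts as an identity mapping and learns how much of the attention
# output to mix in. (Illustrative annotation, not part of the original change.)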
attention_output = target_tensor + self._rezero_a * attention_output attention_output = target_tensor + self._rezero_a * attention_output
if self._use_layer_norm: if self._use_layer_norm:
......
...@@ -58,7 +58,7 @@ class TalkingHeadsAttention(attention.MultiHeadAttention): ...@@ -58,7 +58,7 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
bias_constraint: Constraint for dense layer kernels. bias_constraint: Constraint for dense layer kernels.
""" """
def _build_attention(self, qkv_rank): def build_attention(self, qkv_rank):
"""Builds multi-head dot-product attention computations. """Builds multi-head dot-product attention computations.
This function overrides base class to create additional linear projection This function overrides base class to create additional linear projection
...@@ -67,7 +67,7 @@ class TalkingHeadsAttention(attention.MultiHeadAttention): ...@@ -67,7 +67,7 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
Args: Args:
qkv_rank: the rank of query, key, value tensors after projection. qkv_rank: the rank of query, key, value tensors after projection.
""" """
super(TalkingHeadsAttention, self)._build_attention(qkv_rank) super(TalkingHeadsAttention, self).build_attention(qkv_rank)
# Build an equation: # Build an equation:
# (<batch_dims>, num_heads_a, ...),(num_heads_a, num_heads_b) -> # (<batch_dims>, num_heads_a, ...),(num_heads_a, num_heads_b) ->
...@@ -103,11 +103,11 @@ class TalkingHeadsAttention(attention.MultiHeadAttention): ...@@ -103,11 +103,11 @@ class TalkingHeadsAttention(attention.MultiHeadAttention):
dtype=self.dtype, dtype=self.dtype,
trainable=True) trainable=True)
def _compute_attention(self, def compute_attention(self,
query_tensor, query_tensor,
key_tensor, key_tensor,
value_tensor, value_tensor,
attention_mask=None): attention_mask=None):
"""Applies Dot-product attention with query, key, value tensors. """Applies Dot-product attention with query, key, value tensors.
This function overrides base class to apply additional linear projection This function overrides base class to apply additional linear projection
......
...@@ -46,7 +46,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase): ...@@ -46,7 +46,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80)) query = tf.keras.Input(shape=(40, 80))
value = tf.keras.Input(shape=(20, 80)) value = tf.keras.Input(shape=(20, 80))
output = test_layer([query, value]) output = test_layer(query=query, value=value)
self.assertEqual(output.shape.as_list(), [None] + output_dims) self.assertEqual(output.shape.as_list(), [None] + output_dims)
def test_non_masked_self_attention(self): def test_non_masked_self_attention(self):
...@@ -55,7 +55,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase): ...@@ -55,7 +55,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
num_heads=12, key_size=64) num_heads=12, key_size=64)
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80)) query = tf.keras.Input(shape=(40, 80))
output = test_layer([query, query]) output = test_layer(query=query, value=query)
self.assertEqual(output.shape.as_list(), [None, 40, 80]) self.assertEqual(output.shape.as_list(), [None, 40, 80])
def test_attention_scores(self): def test_attention_scores(self):
...@@ -64,7 +64,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase): ...@@ -64,7 +64,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
num_heads=12, key_size=64, return_attention_scores=True) num_heads=12, key_size=64, return_attention_scores=True)
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80)) query = tf.keras.Input(shape=(40, 80))
output, coef = test_layer([query, query]) output, coef = test_layer(query=query, value=query)
self.assertEqual(output.shape.as_list(), [None, 40, 80]) self.assertEqual(output.shape.as_list(), [None, 40, 80])
self.assertEqual(coef.shape.as_list(), [None, 12, 40, 40]) self.assertEqual(coef.shape.as_list(), [None, 12, 40, 40])
...@@ -78,7 +78,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase): ...@@ -78,7 +78,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
query = tf.keras.Input(shape=(4, 8)) query = tf.keras.Input(shape=(4, 8))
value = tf.keras.Input(shape=(2, 8)) value = tf.keras.Input(shape=(2, 8))
mask_tensor = tf.keras.Input(shape=(4, 2)) mask_tensor = tf.keras.Input(shape=(4, 2))
output = test_layer([query, value], mask_tensor) output = test_layer(query=query, value=value, attention_mask=mask_tensor)
# Create a model containing the test layer. # Create a model containing the test layer.
model = tf.keras.Model([query, value, mask_tensor], output) model = tf.keras.Model([query, value, mask_tensor], output)
...@@ -102,7 +102,8 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase): ...@@ -102,7 +102,8 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
# Tests the layer with three inputs: Q, K, V. # Tests the layer with three inputs: Q, K, V.
key = tf.keras.Input(shape=(2, 8)) key = tf.keras.Input(shape=(2, 8))
output = test_layer([query, value, key], mask_tensor) output = test_layer(
query=query, value=value, key=key, attention_mask=mask_tensor)
model = tf.keras.Model([query, value, key, mask_tensor], output) model = tf.keras.Model([query, value, key, mask_tensor], output)
masked_output_data = model.predict([from_data, to_data, to_data, mask_data]) masked_output_data = model.predict([from_data, to_data, to_data, mask_data])
...@@ -127,7 +128,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase): ...@@ -127,7 +128,7 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02)) kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
# Create a 3-dimensional input (the first dimension is implicit). # Create a 3-dimensional input (the first dimension is implicit).
query = tf.keras.Input(shape=(40, 80)) query = tf.keras.Input(shape=(40, 80))
output = test_layer([query, query]) output = test_layer(query=query, value=query)
self.assertEqual(output.shape.as_list(), [None, 40, 80]) self.assertEqual(output.shape.as_list(), [None, 40, 80])
@parameterized.named_parameters( @parameterized.named_parameters(
...@@ -149,11 +150,12 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase): ...@@ -149,11 +150,12 @@ class TalkingHeadsAttentionTest(keras_parameterized.TestCase):
# Invoke the data with a random set of mask data. This should mask at least # Invoke the data with a random set of mask data. This should mask at least
# one element. # one element.
mask_data = np.random.randint(2, size=mask_shape).astype("bool") mask_data = np.random.randint(2, size=mask_shape).astype("bool")
output = test_layer([query, value], mask_data) output = test_layer(query=query, value=value, attention_mask=mask_data)
# Invoke the same data, but with a null mask (where no elements are masked). # Invoke the same data, but with a null mask (where no elements are masked).
null_mask_data = np.ones(mask_shape) null_mask_data = np.ones(mask_shape)
unmasked_output = test_layer([query, value], null_mask_data) unmasked_output = test_layer(
query=query, value=value, attention_mask=null_mask_data)
# Because one data is masked and one is not, the outputs should not be the # Because one data is masked and one is not, the outputs should not be the
# same. # same.
self.assertNotAllClose(output, unmasked_output) self.assertNotAllClose(output, unmasked_output)
......
...@@ -23,7 +23,6 @@ import gin ...@@ -23,7 +23,6 @@ import gin
import tensorflow as tf import tensorflow as tf
from official.nlp.modeling.layers import attention from official.nlp.modeling.layers import attention
from official.nlp.modeling.layers import dense_einsum
from official.nlp.modeling.layers import multi_channel_attention from official.nlp.modeling.layers import multi_channel_attention
from official.nlp.modeling.layers.util import tf_function_if_eager from official.nlp.modeling.layers.util import tf_function_if_eager
...@@ -106,21 +105,24 @@ class Transformer(tf.keras.layers.Layer): ...@@ -106,21 +105,24 @@ class Transformer(tf.keras.layers.Layer):
"The input size (%d) is not a multiple of the number of attention " "The input size (%d) is not a multiple of the number of attention "
"heads (%d)" % (hidden_size, self._num_heads)) "heads (%d)" % (hidden_size, self._num_heads))
self._attention_head_size = int(hidden_size // self._num_heads) self._attention_head_size = int(hidden_size // self._num_heads)
common_kwargs = dict(
self._attention_layer = attention.MultiHeadAttention(
num_heads=self._num_heads,
key_size=self._attention_head_size,
dropout=self._attention_dropout_rate,
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer, bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer, kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer, bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer, activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint, kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint, bias_constraint=self._bias_constraint)
name="self_attention") self._attention_layer = attention.MultiHeadAttention(
num_heads=self._num_heads,
key_size=self._attention_head_size,
dropout=self._attention_dropout_rate,
name="self_attention",
**common_kwargs)
# pylint: disable=protected-access # pylint: disable=protected-access
    self._attention_layer.build([input_tensor_shape] * 3) # Temporary handling to keep checkpoints compatible.
self._attention_layer._build_from_signature(
query=input_tensor_shape, value=input_tensor_shape)
self._attention_output_dense = self._attention_layer._output_dense self._attention_output_dense = self._attention_layer._output_dense
# pylint: enable=protected-access # pylint: enable=protected-access
self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate) self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
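The build method above now collects the shared initializer, regularizer, and constraint arguments into a single `common_kwargs` dict and unpacks it into each sub-layer instead of repeating the argument list per layer. A standalone sketch of the pattern with stock Keras layers (the values are placeholders, not this layer's defaults):

```python
import tensorflow as tf

# Declared once...
common_kwargs = dict(
    kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
    bias_initializer="zeros",
    kernel_regularizer=None,
    bias_regularizer=None,
    activity_regularizer=None,
    kernel_constraint=None,
    bias_constraint=None)

# ...and reused by every sub-layer that needs the same settings.
intermediate = tf.keras.layers.Dense(128, name="intermediate", **common_kwargs)
output = tf.keras.layers.Dense(64, name="output", **common_kwargs)
```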
...@@ -132,17 +134,12 @@ class Transformer(tf.keras.layers.Layer): ...@@ -132,17 +134,12 @@ class Transformer(tf.keras.layers.Layer):
axis=-1, axis=-1,
epsilon=1e-12, epsilon=1e-12,
dtype=tf.float32)) dtype=tf.float32))
self._intermediate_dense = dense_einsum.DenseEinsum( self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
output_shape=self._intermediate_size, "abc,cd->abd",
activation=None, output_shape=(None, self._intermediate_size),
kernel_initializer=self._kernel_initializer, bias_axes="d",
bias_initializer=self._bias_initializer, name="intermediate",
kernel_regularizer=self._kernel_regularizer, **common_kwargs)
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="intermediate")
policy = tf.keras.mixed_precision.experimental.global_policy() policy = tf.keras.mixed_precision.experimental.global_policy()
if policy.name == "mixed_bfloat16": if policy.name == "mixed_bfloat16":
# bfloat16 causes BERT with the LAMB optimizer to not converge # bfloat16 causes BERT with the LAMB optimizer to not converge
...@@ -151,16 +148,12 @@ class Transformer(tf.keras.layers.Layer): ...@@ -151,16 +148,12 @@ class Transformer(tf.keras.layers.Layer):
policy = tf.float32 policy = tf.float32
self._intermediate_activation_layer = tf.keras.layers.Activation( self._intermediate_activation_layer = tf.keras.layers.Activation(
self._intermediate_activation, dtype=policy) self._intermediate_activation, dtype=policy)
self._output_dense = dense_einsum.DenseEinsum( self._output_dense = tf.keras.layers.experimental.EinsumDense(
output_shape=hidden_size, "abc,cd->abd",
kernel_initializer=self._kernel_initializer, output_shape=(None, hidden_size),
bias_initializer=self._bias_initializer, bias_axes="d",
kernel_regularizer=self._kernel_regularizer, name="output",
bias_regularizer=self._bias_regularizer, **common_kwargs)
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="output")
self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate) self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
# Use float32 in layernorm for numeric stability. # Use float32 in layernorm for numeric stability.
self._output_layer_norm = tf.keras.layers.LayerNormalization( self._output_layer_norm = tf.keras.layers.LayerNormalization(
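Both feed-forward projections now use `tf.keras.layers.experimental.EinsumDense` directly instead of the removed `DenseEinsum` wrapper. With the `"abc,cd->abd"` equation, `output_shape=(None, d)`, and `bias_axes="d"`, the layer contracts only the feature axis, i.e. a position-wise dense projection over `[batch, seq, features]` inputs. A quick sketch, assuming a TF version that provides the experimental layer (it is already referenced elsewhere in this diff):

```python
import tensorflow as tf

hidden_size, intermediate_size = 16, 32

intermediate_dense = tf.keras.layers.experimental.EinsumDense(
    "abc,cd->abd",                           # contract only the feature axis c
    output_shape=(None, intermediate_size),  # None keeps the sequence axis dynamic
    bias_axes="d")                           # one bias term per output feature

x = tf.random.uniform([2, 8, hidden_size])   # [batch, seq_len, hidden]
y = intermediate_dense(x)
print(y.shape)                               # (2, 8, 32)
```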
...@@ -211,9 +204,9 @@ class Transformer(tf.keras.layers.Layer): ...@@ -211,9 +204,9 @@ class Transformer(tf.keras.layers.Layer):
attention_mask = attention_mask[:, 0:self._output_range, :] attention_mask = attention_mask[:, 0:self._output_range, :]
else: else:
target_tensor = input_tensor target_tensor = input_tensor
attention_inputs = [target_tensor, input_tensor]
attention_output = self._attention_layer(attention_inputs, attention_mask) attention_output = self._attention_layer(
query=target_tensor, value=input_tensor, attention_mask=attention_mask)
attention_output = self._attention_dropout(attention_output) attention_output = self._attention_dropout(attention_output)
attention_output = self._attention_layer_norm(target_tensor + attention_output = self._attention_layer_norm(target_tensor +
attention_output) attention_output)
...@@ -312,30 +305,27 @@ class TransformerDecoderLayer(tf.keras.layers.Layer): ...@@ -312,30 +305,27 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
"The hidden size (%d) is not a multiple of the number of attention " "The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (hidden_size, self.num_attention_heads)) "heads (%d)" % (hidden_size, self.num_attention_heads))
self.attention_head_size = int(hidden_size / self.num_attention_heads) self.attention_head_size = int(hidden_size / self.num_attention_heads)
# Self attention. common_kwargs = dict(
self.self_attention = attention.CachedAttention(
num_heads=self.num_attention_heads,
key_size=self.attention_head_size,
dropout=self.attention_dropout_rate,
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer, bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer, kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer, bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer, activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint, kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint, bias_constraint=self._bias_constraint)
name="self_attention") # Self attention.
self.self_attention_output_dense = dense_einsum.DenseEinsum( self.self_attention = attention.CachedAttention(
output_shape=hidden_size, num_heads=self.num_attention_heads,
num_summed_dimensions=2, key_size=self.attention_head_size,
kernel_initializer=self._kernel_initializer, dropout=self.attention_dropout_rate,
bias_initializer=self._bias_initializer, name="self_attention",
kernel_regularizer=self._kernel_regularizer, **common_kwargs)
bias_regularizer=self._bias_regularizer, self.self_attention_output_dense = tf.keras.layers.experimental.EinsumDense(
activity_regularizer=self._activity_regularizer, "abc,cd->abd",
kernel_constraint=self._kernel_constraint, output_shape=(None, hidden_size),
bias_constraint=self._bias_constraint, bias_axes="d",
name="self_attention_output") name="output",
**common_kwargs)
self.self_attention_dropout = tf.keras.layers.Dropout( self.self_attention_dropout = tf.keras.layers.Dropout(
rate=self.dropout_rate) rate=self.dropout_rate)
self.self_attention_layer_norm = ( self.self_attention_layer_norm = (
...@@ -347,14 +337,8 @@ class TransformerDecoderLayer(tf.keras.layers.Layer): ...@@ -347,14 +337,8 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
key_size=self.attention_head_size, key_size=self.attention_head_size,
dropout=self.attention_dropout_rate, dropout=self.attention_dropout_rate,
output_shape=hidden_size, output_shape=hidden_size,
kernel_initializer=self._kernel_initializer, name="attention/encdec",
bias_initializer=self._bias_initializer, **common_kwargs)
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="attention/encdec")
self.encdec_attention_dropout = tf.keras.layers.Dropout( self.encdec_attention_dropout = tf.keras.layers.Dropout(
rate=self.dropout_rate) rate=self.dropout_rate)
...@@ -363,29 +347,20 @@ class TransformerDecoderLayer(tf.keras.layers.Layer): ...@@ -363,29 +347,20 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
name="attention/encdec_output_layer_norm", axis=-1, epsilon=1e-12)) name="attention/encdec_output_layer_norm", axis=-1, epsilon=1e-12))
# Feed-forward projection. # Feed-forward projection.
self.intermediate_dense = dense_einsum.DenseEinsum( self.intermediate_dense = tf.keras.layers.experimental.EinsumDense(
output_shape=self.intermediate_size, "abc,cd->abd",
activation=None, output_shape=(None, self.intermediate_size),
kernel_initializer=self._kernel_initializer, bias_axes="d",
bias_initializer=self._bias_initializer, name="intermediate",
kernel_regularizer=self._kernel_regularizer, **common_kwargs)
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="intermediate")
self.intermediate_activation_layer = tf.keras.layers.Activation( self.intermediate_activation_layer = tf.keras.layers.Activation(
self.intermediate_activation) self.intermediate_activation)
self.output_dense = dense_einsum.DenseEinsum( self.output_dense = tf.keras.layers.experimental.EinsumDense(
output_shape=hidden_size, "abc,cd->abd",
kernel_initializer=self._kernel_initializer, output_shape=(None, hidden_size),
bias_initializer=self._bias_initializer, bias_axes="d",
kernel_regularizer=self._kernel_regularizer, name="output",
bias_regularizer=self._bias_regularizer, **common_kwargs)
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="output")
self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate) self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
self.output_layer_norm = tf.keras.layers.LayerNormalization( self.output_layer_norm = tf.keras.layers.LayerNormalization(
name="output_layer_norm", axis=-1, epsilon=1e-12) name="output_layer_norm", axis=-1, epsilon=1e-12)
...@@ -409,21 +384,23 @@ class TransformerDecoderLayer(tf.keras.layers.Layer): ...@@ -409,21 +384,23 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
"TransformerDecoderLayer must have 4 inputs, but it got: %d" % "TransformerDecoderLayer must have 4 inputs, but it got: %d" %
len(inputs)) len(inputs))
input_tensor, memory, attention_mask, self_attention_mask = inputs[:4] input_tensor, memory, attention_mask, self_attention_mask = inputs[:4]
self_attention_inputs = [input_tensor, input_tensor]
self_attention_output, cache = self.self_attention( self_attention_output, cache = self.self_attention(
self_attention_inputs, query=input_tensor,
value=input_tensor,
attention_mask=self_attention_mask, attention_mask=self_attention_mask,
cache=cache, cache=cache,
decode_loop_step=decode_loop_step) decode_loop_step=decode_loop_step)
self_attention_output = self.self_attention_dropout(self_attention_output) self_attention_output = self.self_attention_dropout(self_attention_output)
self_attention_output = self.self_attention_layer_norm( self_attention_output = self.self_attention_layer_norm(
input_tensor + self_attention_output) input_tensor + self_attention_output)
cross_attn_inputs = dict(
cross_attn_inputs = [self_attention_output, memory] query=self_attention_output,
value=memory,
attention_mask=attention_mask)
if self.multi_channel_cross_attention: if self.multi_channel_cross_attention:
# Accesses the 5-th input tensor for the doc-attention probabilities. # Accesses the 5-th input tensor for the doc-attention probabilities.
cross_attn_inputs.append(inputs[-1]) cross_attn_inputs["context_attention_weights"] = inputs[-1]
attention_output = self.encdec_attention(cross_attn_inputs, attention_mask) attention_output = self.encdec_attention(**cross_attn_inputs)
attention_output = self.encdec_attention_dropout(attention_output) attention_output = self.encdec_attention_dropout(attention_output)
attention_output = self.encdec_attention_layer_norm(self_attention_output + attention_output = self.encdec_attention_layer_norm(self_attention_output +
attention_output) attention_output)
...
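The decoder's cross-attention call above also moves to keyword form: the required inputs go into a `cross_attn_inputs` dict, and the optional multi-channel probabilities are attached under an explicit `context_attention_weights` key rather than by list position. A toy, self-contained illustration of that dispatch pattern (the function below is a stand-in for the attention layer, not the real API):

```python
def fake_attention(query, value, attention_mask=None,
                   context_attention_weights=None):
  """Stand-in for the encoder-decoder attention layer."""
  return query, value, attention_mask, context_attention_weights

cross_attn_inputs = dict(query="q", value="memory", attention_mask="mask")
multi_channel_cross_attention = True
if multi_channel_cross_attention:
  # The optional input is added by name only when the layer expects it.
  cross_attn_inputs["context_attention_weights"] = "doc_probs"

print(fake_attention(**cross_attn_inputs))
```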
...@@ -262,9 +262,8 @@ class TransformerScaffold(tf.keras.layers.Layer): ...@@ -262,9 +262,8 @@ class TransformerScaffold(tf.keras.layers.Layer):
else: else:
input_tensor, attention_mask = (inputs, None) input_tensor, attention_mask = (inputs, None)
attention_inputs = [input_tensor, input_tensor] attention_output = self._attention_layer(
query=input_tensor, value=input_tensor, attention_mask=attention_mask)
attention_output = self._attention_layer(attention_inputs, attention_mask)
attention_output = self._attention_dropout(attention_output) attention_output = self._attention_dropout(attention_output)
attention_output = self._attention_layer_norm(input_tensor + attention_output = self._attention_layer_norm(input_tensor +
attention_output) attention_output)
...
...@@ -39,10 +39,10 @@ class ValidatedAttentionLayer(attention.MultiHeadAttention): ...@@ -39,10 +39,10 @@ class ValidatedAttentionLayer(attention.MultiHeadAttention):
super(ValidatedAttentionLayer, self).__init__(**kwargs) super(ValidatedAttentionLayer, self).__init__(**kwargs)
self.list = call_list self.list = call_list
def call(self, inputs, attention_mask=None): def call(self, query, value, attention_mask=None):
self.list.append(True) self.list.append(True)
return super(ValidatedAttentionLayer, self).call( return super(ValidatedAttentionLayer, self).call(
inputs, attention_mask=attention_mask) query, value, attention_mask=attention_mask)
def get_config(self): def get_config(self):
config = super(ValidatedAttentionLayer, self).get_config() config = super(ValidatedAttentionLayer, self).get_config()
...
...@@ -152,7 +152,10 @@ class TransformerLayerTest(keras_parameterized.TestCase): ...@@ -152,7 +152,10 @@ class TransformerLayerTest(keras_parameterized.TestCase):
_ = new_layer([input_data, mask_data]) _ = new_layer([input_data, mask_data])
new_layer.set_weights(test_layer.get_weights()) new_layer.set_weights(test_layer.get_weights())
new_output_tensor = new_layer([input_data, mask_data]) new_output_tensor = new_layer([input_data, mask_data])
self.assertAllClose(new_output_tensor, output_tensor[:, 0:1, :]) self.assertAllClose(new_output_tensor,
output_tensor[:, 0:1, :],
atol=5e-5,
rtol=0.003)
def test_layer_invocation_with_float16_dtype(self, transformer_cls): def test_layer_invocation_with_float16_dtype(self, transformer_cls):
tf.keras.mixed_precision.experimental.set_policy('mixed_float16') tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
...
...@@ -4,6 +4,3 @@ Losses contains common loss computation used in NLP tasks. ...@@ -4,6 +4,3 @@ Losses contains common loss computation used in NLP tasks.
* `weighted_sparse_categorical_crossentropy_loss` computes per-batch sparse * `weighted_sparse_categorical_crossentropy_loss` computes per-batch sparse
categorical crossentropy loss. categorical crossentropy loss.
* `weighted_sparse_categorical_crossentropy_per_example_loss` computes
per-example sparse categorical crossentropy loss.
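With the per-example helper removed, the package exposes only the per-batch loss. A short usage sketch, assuming `official.nlp.modeling.losses` is importable; the data below is random and only illustrates the expected shapes:

```python
import numpy as np
from official.nlp.modeling import losses

batch_size, num_predictions, vocab_size = 3, 21, 100

# Post-softmax predictions, integer labels, and a 0/1 weight mask.
predictions = np.random.uniform(size=(batch_size, num_predictions, vocab_size))
predictions /= predictions.sum(axis=-1, keepdims=True)
labels = np.random.randint(vocab_size, size=(batch_size, num_predictions))
weights = np.random.randint(2, size=(batch_size, num_predictions))

loss_value = losses.weighted_sparse_categorical_crossentropy_loss(
    labels=labels, predictions=predictions, weights=weights)
```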
...@@ -14,4 +14,3 @@ ...@@ -14,4 +14,3 @@
# ============================================================================== # ==============================================================================
"""Activations package definition. Subject to change.""" """Activations package definition. Subject to change."""
from official.nlp.modeling.losses.weighted_sparse_categorical_crossentropy import loss as weighted_sparse_categorical_crossentropy_loss from official.nlp.modeling.losses.weighted_sparse_categorical_crossentropy import loss as weighted_sparse_categorical_crossentropy_loss
from official.nlp.modeling.losses.weighted_sparse_categorical_crossentropy import per_example_loss as weighted_sparse_categorical_crossentropy_per_example_loss
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Sparse categorical cross-entropy losses.""" """Weighted sparse categorical cross-entropy losses."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -43,37 +43,7 @@ def _validate_rank(labels, predictions, weights): ...@@ -43,37 +43,7 @@ def _validate_rank(labels, predictions, weights):
"predictions.shape was %s.") % (labels.shape, predictions.shape)) "predictions.shape was %s.") % (labels.shape, predictions.shape))
def per_example_loss(labels, predictions, weights=None): def loss(labels, predictions, weights=None, from_logits=False):
"""Calculate a per-example sparse categorical crossentropy loss.
This loss function assumes that the predictions are post-softmax.
Args:
labels: The labels to evaluate against. Should be a set of integer indices
ranging from 0 to (vocab_size-1).
predictions: The network predictions. Should have softmax already applied.
weights: An optional weight array of the same shape as the 'labels' array.
If None, all examples will be used.
Returns:
A tensor of shape predictions.shape[:-1] containing the per-example
loss.
"""
# When using these functions with the Keras core API, we will need to squeeze
# the labels tensor - Keras adds a spurious inner dimension.
labels, predictions = _adjust_labels(labels, predictions)
_validate_rank(labels, predictions, weights)
labels_one_hot = tf.one_hot(labels, predictions.shape[-1])
labels_one_hot = tf.cast(labels_one_hot, predictions.dtype)
per_example_loss_data = -tf.reduce_sum(
predictions * labels_one_hot, axis=[-1])
if weights is not None:
weights = tf.cast(weights, per_example_loss_data.dtype)
per_example_loss_data = weights * per_example_loss_data
return per_example_loss_data
def loss(labels, predictions, weights=None):
"""Calculate a per-batch sparse categorical crossentropy loss. """Calculate a per-batch sparse categorical crossentropy loss.
This loss function assumes that the predictions are post-softmax. This loss function assumes that the predictions are post-softmax.
...@@ -83,6 +53,7 @@ def loss(labels, predictions, weights=None): ...@@ -83,6 +53,7 @@ def loss(labels, predictions, weights=None):
predictions: The network predictions. Should have softmax already applied. predictions: The network predictions. Should have softmax already applied.
weights: An optional weight array of the same shape as the 'labels' array. weights: An optional weight array of the same shape as the 'labels' array.
If None, all examples will be used. If None, all examples will be used.
      from_logits: Whether the predictions are logits. If False, they are assumed to be post-softmax probabilities.
Returns: Returns:
A loss scalar. A loss scalar.
...@@ -95,12 +66,11 @@ def loss(labels, predictions, weights=None): ...@@ -95,12 +66,11 @@ def loss(labels, predictions, weights=None):
labels, predictions = _adjust_labels(labels, predictions) labels, predictions = _adjust_labels(labels, predictions)
_validate_rank(labels, predictions, weights) _validate_rank(labels, predictions, weights)
per_example_loss_data = per_example_loss(labels, predictions, weights) example_losses = tf.keras.losses.sparse_categorical_crossentropy(
labels, predictions, from_logits=from_logits)
if weights is None: if weights is None:
return tf.reduce_mean(per_example_loss_data) return tf.reduce_mean(example_losses)
else: weights = tf.cast(weights, predictions.dtype)
numerator = tf.reduce_sum(per_example_loss_data) return tf.math.divide_no_nan(
weights = tf.cast(weights, predictions.dtype) tf.reduce_sum(example_losses * weights), tf.reduce_sum(weights))
denominator = tf.reduce_sum(weights) + 1e-5
return numerator / denominator
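The reduction itself also changes: instead of weighting the per-example losses and dividing their sum by `sum(weights) + 1e-5`, the refactored `loss` delegates the per-example computation to `tf.keras.losses.sparse_categorical_crossentropy` and takes a weighted mean with `tf.math.divide_no_nan`, which returns exactly 0 when every weight is 0 (the fully masked case covered by the tests below). A small sketch of just the reduction step:

```python
import tensorflow as tf

example_losses = tf.constant([1.0, 2.0, 4.0])
weights = tf.constant([1.0, 1.0, 0.0])

# New reduction: weighted mean, exactly 0.0 when all weights are zero.
new_style = tf.math.divide_no_nan(
    tf.reduce_sum(example_losses * weights), tf.reduce_sum(weights))  # 1.5

# Old reduction for comparison: the 1e-5 term slightly biases the result.
old_style = tf.reduce_sum(example_losses * weights) / (
    tf.reduce_sum(weights) + 1e-5)
```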
...@@ -53,8 +53,7 @@ class ClassificationLossTest(keras_parameterized.TestCase): ...@@ -53,8 +53,7 @@ class ClassificationLossTest(keras_parameterized.TestCase):
# Create a maskedLM from the transformer stack. # Create a maskedLM from the transformer stack.
test_layer = layers.MaskedLM( test_layer = layers.MaskedLM(
embedding_table=xformer_stack.get_embedding_table(), embedding_table=xformer_stack.get_embedding_table(), output=output)
output=output)
# Create a model from the masked LM layer. # Create a model from the masked LM layer.
lm_input_tensor = tf.keras.Input(shape=(sequence_length, hidden_size)) lm_input_tensor = tf.keras.Input(shape=(sequence_length, hidden_size))
...@@ -63,123 +62,6 @@ class ClassificationLossTest(keras_parameterized.TestCase): ...@@ -63,123 +62,6 @@ class ClassificationLossTest(keras_parameterized.TestCase):
output = test_layer(lm_input_tensor, masked_positions=masked_lm_positions) output = test_layer(lm_input_tensor, masked_positions=masked_lm_positions)
return tf.keras.Model([lm_input_tensor, masked_lm_positions], output) return tf.keras.Model([lm_input_tensor, masked_lm_positions], output)
def create_classification_model(self, input_width, num_classes):
test_object = networks.Classification(
input_width=input_width, num_classes=num_classes)
# Create a 2-dimensional input (the first dimension is implicit).
pooled_data = tf.keras.Input(shape=(input_width,), dtype=tf.float32)
output = test_object(pooled_data)
return tf.keras.Model(pooled_data, output)
def test_per_example_loss_3d_input(self):
"""Test per-example loss with a 3-dimensional input, from a masked LM."""
vocab_size = 100
sequence_length = 32
hidden_size = 64
num_predictions = 21
model = self.create_lm_model(
vocab_size=vocab_size,
sequence_length=sequence_length,
hidden_size=hidden_size,
num_predictions=num_predictions)
# Get the output of the masked LM.
batch_size = 3
lm_input_data = 10 * np.random.random_sample(
(batch_size, sequence_length, hidden_size))
masked_position_data = np.random.randint(
2, size=(batch_size, num_predictions))
output_data = model.predict([lm_input_data, masked_position_data])
# Calculate per-example loss.
labels = np.random.randint(vocab_size, size=(batch_size, num_predictions))
per_example_loss_data = weighted_sparse_categorical_crossentropy.per_example_loss(
predictions=output_data, labels=labels)
# Per-example loss data should have one value per prediction, and those
# values shouldn't be zero in this case (as we're using random data).
expected_shape = [batch_size, num_predictions]
self.assertEqual(expected_shape, per_example_loss_data.shape.as_list())
self.assertNotAllClose(
tf.zeros_like(per_example_loss_data), per_example_loss_data)
def test_per_example_loss_2d_input(self):
"""Test per-example loss with a 2-d input, from a classifier."""
input_width = 512
num_classes = 10
model = self.create_classification_model(input_width, num_classes)
# Invoke the network as part of a Model.
batch_size = 3
input_data = 10 * np.random.random_sample((batch_size, input_width))
output_data = model.predict(input_data)
# Calculate per example loss.
labels = np.random.randint(num_classes, size=(batch_size))
per_example_loss_data = weighted_sparse_categorical_crossentropy.per_example_loss(
predictions=output_data, labels=labels)
# Per-example loss data should have one value per batch item, and those
# values shouldn't be zero in this case (as we're using random data).
self.assertEqual([batch_size], per_example_loss_data.shape.as_list())
self.assertNotAllClose(
tf.zeros_like(per_example_loss_data), per_example_loss_data)
def test_per_example_loss_weights_3d_input(self):
"""Test weighted per-example loss with a 3-d input, from a masked LM."""
vocab_size = 100
sequence_length = 32
hidden_size = 64
num_predictions = 21
model = self.create_lm_model(
vocab_size=vocab_size,
sequence_length=sequence_length,
hidden_size=hidden_size,
num_predictions=num_predictions)
# Get the output of the masked LM.
batch_size = 3
lm_input_data = 10 * np.random.random_sample(
(batch_size, sequence_length, hidden_size))
masked_position_data = np.random.randint(
2, size=(batch_size, num_predictions))
output_data = model.predict([lm_input_data, masked_position_data])
# Calculate per-example loss with weights.
labels = np.random.randint(vocab_size, size=(batch_size, num_predictions))
weights = np.random.randint(2, size=(batch_size, num_predictions))
per_example_loss_data = weighted_sparse_categorical_crossentropy.per_example_loss(
predictions=output_data, labels=labels, weights=weights)
# Weighted per-example loss data should be equivalent to multiplying the
# loss tensor by the weights tensor.
expected_weighted_loss = per_example_loss_data * weights
self.assertAllClose(expected_weighted_loss, per_example_loss_data)
def test_per_example_loss_weights_2d_input(self):
"""Test weighted per-example loss with a 2-d input, from a classifier."""
input_width = 512
num_classes = 10
model = self.create_classification_model(input_width, num_classes)
# Invoke the network as part of a Model.
batch_size = 3
input_data = 10 * np.random.random_sample((batch_size, input_width))
output_data = model.predict(input_data)
# Calculate per-example loss with weights.
labels = np.random.randint(num_classes, size=(batch_size))
weights = np.random.randint(2, size=(batch_size))
per_example_loss_data = weighted_sparse_categorical_crossentropy.per_example_loss(
predictions=output_data, labels=labels, weights=weights)
# Weighted per-example loss data should be equivalent to multiplying the
# loss tensor by the weights tensor.
expected_weighted_loss = per_example_loss_data * weights
self.assertAllClose(expected_weighted_loss, per_example_loss_data)
def test_loss_3d_input(self): def test_loss_3d_input(self):
"""Test overall loss with a 3-dimensional input, from a masked LM.""" """Test overall loss with a 3-dimensional input, from a masked LM."""
vocab_size = 100 vocab_size = 100
...@@ -213,26 +95,6 @@ class ClassificationLossTest(keras_parameterized.TestCase): ...@@ -213,26 +95,6 @@ class ClassificationLossTest(keras_parameterized.TestCase):
self.assertNotAllClose( self.assertNotAllClose(
tf.zeros_like(per_example_loss_data), per_example_loss_data) tf.zeros_like(per_example_loss_data), per_example_loss_data)
def test_loss_2d_input(self):
"""Test overall loss with a 2-d input, from a classifier."""
input_width = 512
num_classes = 10
model = self.create_classification_model(input_width, num_classes)
# Invoke the network as part of a Model.
batch_size = 3
input_data = 10 * np.random.random_sample((batch_size, input_width))
output_data = model.predict(input_data)
# Calculate per example loss.
labels = np.random.randint(num_classes, size=(batch_size))
loss_data = weighted_sparse_categorical_crossentropy.loss(
predictions=output_data, labels=labels)
# Loss data should have one value only, and that value shouldn't be zero in
# this case (as we're using random data).
self.assertNotAllClose(0, loss_data)
def test_loss_weights_3d_input(self): def test_loss_weights_3d_input(self):
"""Test masked loss with a 3-dimensional input, from a masked LM.""" """Test masked loss with a 3-dimensional input, from a masked LM."""
vocab_size = 100 vocab_size = 100
...@@ -262,26 +124,6 @@ class ClassificationLossTest(keras_parameterized.TestCase): ...@@ -262,26 +124,6 @@ class ClassificationLossTest(keras_parameterized.TestCase):
# Because the tensor is fully masked, the loss should be 0. # Because the tensor is fully masked, the loss should be 0.
self.assertAllClose(0, weighted_loss_data) self.assertAllClose(0, weighted_loss_data)
def test_loss_weights_2d_input(self):
"""Test masked loss with a 2-d input, from a classifier."""
input_width = 512
num_classes = 10
model = self.create_classification_model(input_width, num_classes)
# Invoke the network as part of a Model.
batch_size = 3
input_data = 10 * np.random.random_sample((batch_size, input_width))
output_data = model.predict(input_data)
# Calculate a fully masked weight tensor. This should give a loss of zero.
labels = np.random.randint(num_classes, size=(batch_size))
null_weights = np.zeros((batch_size))
weighted_loss_data = weighted_sparse_categorical_crossentropy.loss(
predictions=output_data, labels=labels, weights=null_weights)
# Because the tensor is fully masked, the loss should be 0.
self.assertAllClose(0, weighted_loss_data)
def test_mismatched_predictions_and_labels_ranks_squeezes(self): def test_mismatched_predictions_and_labels_ranks_squeezes(self):
"""Test that the loss asserts when rank(predictions)-1 != rank(labels).""" """Test that the loss asserts when rank(predictions)-1 != rank(labels)."""
batch_size = 3 batch_size = 3
...@@ -289,7 +131,7 @@ class ClassificationLossTest(keras_parameterized.TestCase): ...@@ -289,7 +131,7 @@ class ClassificationLossTest(keras_parameterized.TestCase):
labels = np.random.randint(10, size=(batch_size, 1)) labels = np.random.randint(10, size=(batch_size, 1))
# All that this test tests is that the squeeze is successful. # All that this test tests is that the squeeze is successful.
_ = weighted_sparse_categorical_crossentropy.per_example_loss( _ = weighted_sparse_categorical_crossentropy.loss(
predictions=output_data, labels=labels) predictions=output_data, labels=labels)
def test_mismatched_weights_and_labels_ranks_fail(self): def test_mismatched_weights_and_labels_ranks_fail(self):
...@@ -299,9 +141,6 @@ class ClassificationLossTest(keras_parameterized.TestCase): ...@@ -299,9 +141,6 @@ class ClassificationLossTest(keras_parameterized.TestCase):
labels = np.random.randint(10, size=(batch_size, 10)) labels = np.random.randint(10, size=(batch_size, 10))
weights = np.random.randint(2, size=(batch_size)) weights = np.random.randint(2, size=(batch_size))
with self.assertRaisesRegex(RuntimeError, ".*of the same rank.*"):
_ = weighted_sparse_categorical_crossentropy.per_example_loss(
predictions=output_data, labels=labels, weights=weights)
with self.assertRaisesRegex(RuntimeError, ".*of the same rank.*"): with self.assertRaisesRegex(RuntimeError, ".*of the same rank.*"):
_ = weighted_sparse_categorical_crossentropy.loss( _ = weighted_sparse_categorical_crossentropy.loss(
predictions=output_data, labels=labels, weights=weights) predictions=output_data, labels=labels, weights=weights)
...@@ -317,8 +156,6 @@ class ClassificationLossTest(keras_parameterized.TestCase): ...@@ -317,8 +156,6 @@ class ClassificationLossTest(keras_parameterized.TestCase):
# We're not trying to validate numerical correctness, just ensure that # We're not trying to validate numerical correctness, just ensure that
# we can in fact pass tensors to these functions without causing runtime # we can in fact pass tensors to these functions without causing runtime
# errors from the shape checking code. # errors from the shape checking code.
_ = weighted_sparse_categorical_crossentropy.per_example_loss(
predictions=output_data, labels=labels, weights=weights)
_ = weighted_sparse_categorical_crossentropy.loss( _ = weighted_sparse_categorical_crossentropy.loss(
predictions=output_data, labels=labels, weights=weights) predictions=output_data, labels=labels, weights=weights)
...@@ -338,20 +175,15 @@ class ClassificationLossTest(keras_parameterized.TestCase): ...@@ -338,20 +175,15 @@ class ClassificationLossTest(keras_parameterized.TestCase):
[-2.7760355, -1.8219438, -3.0924666, -1.0779881, -0.9407509]]]) [-2.7760355, -1.8219438, -3.0924666, -1.0779881, -0.9407509]]])
labels = np.array([[4, 0], [2, 2], [2, 1]]) labels = np.array([[4, 0], [2, 2], [2, 1]])
# Validate that per_example loss calculations are the same.
per_example_loss_data = weighted_sparse_categorical_crossentropy.per_example_loss(
predictions=output_data, labels=labels)
expected_per_example_loss_data = [[1.2923571, 2.7117882],
[2.287932, 2.287932],
[3.0924666, 1.8219438]]
self.assertAllClose(expected_per_example_loss_data, per_example_loss_data)
# Validate that overall loss calculations are the same. # Validate that overall loss calculations are the same.
weights = np.array([[1, 0], [0, 0], [0, 0]]) weights = np.array([[1, 0], [0, 0], [0, 0]])
loss_data = weighted_sparse_categorical_crossentropy.loss( loss_data = weighted_sparse_categorical_crossentropy.loss(
predictions=output_data, labels=labels, weights=weights) predictions=output_data,
labels=labels,
weights=weights,
from_logits=True)
expected_loss_data = 1.2923441 expected_loss_data = 1.2923441
self.assertAllClose(expected_loss_data, loss_data) self.assertAllClose(expected_loss_data, loss_data, rtol=1e-3)
def test_legacy_classification_loss_compatibility(self): def test_legacy_classification_loss_compatibility(self):
"""Test to validate computational correctness during refactors.""" """Test to validate computational correctness during refactors."""
...@@ -362,19 +194,15 @@ class ClassificationLossTest(keras_parameterized.TestCase): ...@@ -362,19 +194,15 @@ class ClassificationLossTest(keras_parameterized.TestCase):
[-1.6975292e-03, -6.4009643e+00, -1.0226612e+01]]) [-1.6975292e-03, -6.4009643e+00, -1.0226612e+01]])
labels = np.array([2, 1]) labels = np.array([2, 1])
# Validate that per_example loss calculations are the same.
per_example_loss_data = weighted_sparse_categorical_crossentropy.per_example_loss(
predictions=output_data, labels=labels)
expected_per_example_loss_data = [6.4434357, 6.4009643]
self.assertAllClose(expected_per_example_loss_data, per_example_loss_data)
# Validate that overall loss calculations are the same. # Validate that overall loss calculations are the same.
weights = None weights = None
loss_data = weighted_sparse_categorical_crossentropy.loss( loss_data = weighted_sparse_categorical_crossentropy.loss(
predictions=output_data, labels=labels, weights=weights) predictions=output_data,
labels=labels,
weights=weights,
from_logits=True)
expected_loss_data = 6.4222 expected_loss_data = 6.4222
self.assertAllClose(expected_loss_data, loss_data) self.assertAllClose(expected_loss_data, loss_data, rtol=1e-3)
if __name__ == "__main__": if __name__ == "__main__":
tf.test.main() tf.test.main()
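Since `per_example_loss` and its tests are deleted, callers that still need unreduced values can compute them directly with Keras and apply their own weights. A hedged sketch with random data (pass `from_logits=True` if the predictions are logits):

```python
import numpy as np
import tensorflow as tf

batch_size, num_classes = 3, 10
logits = tf.random.uniform([batch_size, num_classes])
predictions = tf.nn.softmax(logits)               # post-softmax probabilities
labels = np.random.randint(num_classes, size=(batch_size,))
weights = np.random.randint(2, size=(batch_size,)).astype(np.float32)

# Per-example sparse categorical crossentropy, then manual weighting.
per_example = tf.keras.losses.sparse_categorical_crossentropy(
    labels, predictions, from_logits=False)
weighted_per_example = per_example * weights
```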
...@@ -17,3 +17,4 @@ from official.nlp.modeling.models.bert_classifier import BertClassifier ...@@ -17,3 +17,4 @@ from official.nlp.modeling.models.bert_classifier import BertClassifier
from official.nlp.modeling.models.bert_pretrainer import BertPretrainer from official.nlp.modeling.models.bert_pretrainer import BertPretrainer
from official.nlp.modeling.models.bert_span_labeler import BertSpanLabeler from official.nlp.modeling.models.bert_span_labeler import BertSpanLabeler
from official.nlp.modeling.models.bert_token_classifier import BertTokenClassifier from official.nlp.modeling.models.bert_token_classifier import BertTokenClassifier
from official.nlp.modeling.models.electra_pretrainer import ElectraPretrainer