Commit 43f5340f authored by Hongkun Yu, committed by A. Unique TensorFlower

Migrate all DenseEinsum to tf.keras.experimental.EinsumDense

PiperOrigin-RevId: 320240466
parent 180f2607
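The pattern is the same in every file below: a `DenseEinsum(output_shape=...)` layer becomes a `tf.keras.layers.experimental.EinsumDense` with an explicit einsum equation and `bias_axes`. A minimal sketch of the new call (not part of the commit, using hypothetical sizes):

```python
import tensorflow as tf

hidden_size, intermediate_size = 768, 3072  # hypothetical sizes

# "abc,cd->abd" contracts the last input axis (c) against a kernel of shape
# (c, d) and emits a new last axis d; bias_axes="d" adds a bias over that axis.
dense = tf.keras.layers.experimental.EinsumDense(
    "abc,cd->abd",
    output_shape=(None, intermediate_size),  # None keeps the sequence axis dynamic
    bias_axes="d")

x = tf.ones([2, 16, hidden_size])  # (batch, seq_len, hidden)
y = dense(x)                       # -> (batch, seq_len, intermediate_size)
print(y.shape)                     # (2, 16, 3072)
```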
@@ -3,11 +3,6 @@
 Layers are the fundamental building blocks for NLP models. They can be used to
 assemble new layers, networks, or models.
-*   [DenseEinsum](dense_einsum.py) implements a feedforward network using
-    tf.einsum. This layer contains the einsum op, the associated weight, and the
-    logic required to generate the einsum expression for the given
-    initialization parameters.
-
 *   [MultiHeadAttention](attention.py) implements an optionally masked attention
     between query, key, value tensors as described in
     ["Attention Is All You Need"](https://arxiv.org/abs/1706.03762). If
...
@@ -21,6 +21,8 @@ from __future__ import print_function

 import tensorflow as tf

+from tensorflow.python.util import deprecation
+
 _CHR_IDX = ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m"]
@@ -57,6 +59,9 @@ class DenseEinsum(tf.keras.layers.Layer):
     `(batch_size, units)`.
   """

+  @deprecation.deprecated(
+      None, "DenseEinsum is deprecated. Please use "
+      "tf.keras.experimental.EinsumDense layer instead.")
   def __init__(self,
                output_shape,
                num_summed_dimensions=1,
...
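For context on the new decorator: `deprecation.deprecated(date, instructions)` from `tensorflow.python.util` wraps a callable so that invoking it logs a deprecation warning (passing `None` as the date marks it deprecated without a scheduled removal). A small illustration with a hypothetical helper, not code from this commit:

```python
import tensorflow as tf
from tensorflow.python.util import deprecation


@deprecation.deprecated(
    None, "old_projection is deprecated. Please use "
    "tf.keras.layers.experimental.EinsumDense instead.")
def old_projection(x, kernel):
  # Hypothetical stand-in: calling it logs a deprecation warning pointing at
  # the replacement API, then runs the wrapped function as usual.
  return tf.einsum("abc,cd->abd", x, kernel)


y = old_projection(tf.ones([2, 4, 8]), tf.ones([8, 16]))  # warns, returns shape (2, 4, 16)
```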
@@ -26,7 +26,6 @@ import math

 import tensorflow as tf

 from official.modeling import tf_utils
 from official.nlp.modeling.layers import attention
-from official.nlp.modeling.layers import dense_einsum
 from official.nlp.modeling.layers import masked_softmax
@@ -67,28 +66,26 @@ class VotingAttention(tf.keras.layers.Layer):
     self._bias_constraint = tf.keras.constraints.get(bias_constraint)

   def build(self, unused_input_shapes):
-    self._query_dense = dense_einsum.DenseEinsum(
-        output_shape=(self._num_heads, self._head_size),
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        dtype=self.dtype,
-        name="encdocatt_query")
-    self._key_dense = dense_einsum.DenseEinsum(
-        output_shape=(self._num_heads, self._head_size),
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        dtype=self.dtype,
-        name="encdocatt_key")
+    common_kwargs = dict(
+        kernel_initializer=self._kernel_initializer,
+        bias_initializer=self._bias_initializer,
+        kernel_regularizer=self._kernel_regularizer,
+        bias_regularizer=self._bias_regularizer,
+        activity_regularizer=self._activity_regularizer,
+        kernel_constraint=self._kernel_constraint,
+        bias_constraint=self._bias_constraint)
+    self._query_dense = tf.keras.layers.experimental.EinsumDense(
+        "BAE,ENH->BANH",
+        output_shape=(None, self._num_heads, self._head_size),
+        bias_axes="NH",
+        name="query",
+        **common_kwargs)
+    self._key_dense = tf.keras.layers.experimental.EinsumDense(
+        "BAE,ENH->BANH",
+        output_shape=(None, self._num_heads, self._head_size),
+        bias_axes="NH",
+        name="key",
+        **common_kwargs)
     super(VotingAttention, self).build(unused_input_shapes)

   def call(self, encoder_outputs, doc_attention_mask):
...
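A quick shape check of the new query/key projections above, with hypothetical dimensions: `"BAE,ENH->BANH"` contracts the embedding axis `E` against a kernel of shape `(E, N, H)`, so the layer maps `(batch, num_docs, embed)` to `(batch, num_docs, num_heads, head_size)`, the same output the old `DenseEinsum(output_shape=(num_heads, head_size))` produced.

```python
import tensorflow as tf

num_heads, head_size, embed = 4, 16, 64  # hypothetical sizes
query_dense = tf.keras.layers.experimental.EinsumDense(
    "BAE,ENH->BANH",
    output_shape=(None, num_heads, head_size),  # None keeps the A (docs) axis dynamic
    bias_axes="NH")

x = tf.ones([2, 8, embed])   # (batch, num_docs, embed)
print(query_dense(x).shape)  # (2, 8, 4, 16)
```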
@@ -23,7 +23,6 @@ import gin
 import tensorflow as tf

 from official.nlp.modeling.layers import attention
-from official.nlp.modeling.layers import dense_einsum


 @tf.keras.utils.register_keras_serializable(package="Text")
@@ -109,19 +108,20 @@ class ReZeroTransformer(tf.keras.layers.Layer):
           "The input size (%d) is not a multiple of the number of attention "
           "heads (%d)" % (hidden_size, self._num_heads))
     self._attention_head_size = int(hidden_size // self._num_heads)
-
-    self._attention_layer = attention.MultiHeadAttention(
-        num_heads=self._num_heads,
-        key_size=self._attention_head_size,
-        dropout=self._attention_dropout_rate,
+    common_kwargs = dict(
         kernel_initializer=self._kernel_initializer,
         bias_initializer=self._bias_initializer,
         kernel_regularizer=self._kernel_regularizer,
         bias_regularizer=self._bias_regularizer,
         activity_regularizer=self._activity_regularizer,
         kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="self_attention")
+        bias_constraint=self._bias_constraint)
+    self._attention_layer = attention.MultiHeadAttention(
+        num_heads=self._num_heads,
+        key_size=self._attention_head_size,
+        dropout=self._attention_dropout_rate,
+        name="self_attention",
+        **common_kwargs)
     self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
     if self._use_layer_norm:
       # Use float32 in layernorm for numeric stability.
@@ -132,17 +132,12 @@ class ReZeroTransformer(tf.keras.layers.Layer):
              axis=-1,
              epsilon=1e-12,
              dtype=tf.float32))
-    self._intermediate_dense = dense_einsum.DenseEinsum(
-        output_shape=self._intermediate_size,
-        activation=None,
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="intermediate")
+    self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
+        "abc,cd->abd",
+        output_shape=(None, self._intermediate_size),
+        bias_axes="d",
+        name="intermediate",
+        **common_kwargs)
     policy = tf.keras.mixed_precision.experimental.global_policy()
     if policy.name == "mixed_bfloat16":
       # bfloat16 causes BERT with the LAMB optimizer to not converge
@@ -151,16 +146,12 @@ class ReZeroTransformer(tf.keras.layers.Layer):
       policy = tf.float32
     self._intermediate_activation_layer = tf.keras.layers.Activation(
         self._intermediate_activation, dtype=policy)
-    self._output_dense = dense_einsum.DenseEinsum(
-        output_shape=hidden_size,
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="output")
+    self._output_dense = tf.keras.layers.experimental.EinsumDense(
+        "abc,cd->abd",
+        output_shape=(None, hidden_size),
+        bias_axes="d",
+        name="output",
+        **common_kwargs)
     self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
     if self._use_layer_norm:
       # Use float32 in layernorm for numeric stability.
...
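As a sanity check on the feed-forward swap (a hypothetical test, not part of the commit): for a rank-3 input, `EinsumDense("abc,cd->abd", ...)` performs the same last-axis projection as a `Dense` layer, so copying the kernel and bias from one to the other should give matching outputs.

```python
import numpy as np
import tensorflow as tf

hidden, units = 8, 32                 # hypothetical sizes
x = tf.random.normal([2, 5, hidden])  # (batch, seq_len, hidden)

einsum_dense = tf.keras.layers.experimental.EinsumDense(
    "abc,cd->abd", output_shape=(None, units), bias_axes="d")
dense = tf.keras.layers.Dense(units)

_ = einsum_dense(x), dense(x)                  # build both layers
dense.set_weights(einsum_dense.get_weights())  # kernel (hidden, units), bias (units,)
np.testing.assert_allclose(
    einsum_dense(x).numpy(), dense(x).numpy(), rtol=1e-5, atol=1e-5)
```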
@@ -23,7 +23,6 @@ import gin
 import tensorflow as tf

 from official.nlp.modeling.layers import attention
-from official.nlp.modeling.layers import dense_einsum
 from official.nlp.modeling.layers import multi_channel_attention
 from official.nlp.modeling.layers.util import tf_function_if_eager
@@ -106,19 +105,20 @@ class Transformer(tf.keras.layers.Layer):
           "The input size (%d) is not a multiple of the number of attention "
           "heads (%d)" % (hidden_size, self._num_heads))
     self._attention_head_size = int(hidden_size // self._num_heads)
-
-    self._attention_layer = attention.MultiHeadAttention(
-        num_heads=self._num_heads,
-        key_size=self._attention_head_size,
-        dropout=self._attention_dropout_rate,
+    common_kwargs = dict(
         kernel_initializer=self._kernel_initializer,
         bias_initializer=self._bias_initializer,
         kernel_regularizer=self._kernel_regularizer,
         bias_regularizer=self._bias_regularizer,
         activity_regularizer=self._activity_regularizer,
         kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="self_attention")
+        bias_constraint=self._bias_constraint)
+    self._attention_layer = attention.MultiHeadAttention(
+        num_heads=self._num_heads,
+        key_size=self._attention_head_size,
+        dropout=self._attention_dropout_rate,
+        name="self_attention",
+        **common_kwargs)
     # pylint: disable=protected-access
     self._attention_layer.build([input_tensor_shape] * 3)
     self._attention_output_dense = self._attention_layer._output_dense
@@ -132,17 +132,12 @@ class Transformer(tf.keras.layers.Layer):
             axis=-1,
             epsilon=1e-12,
             dtype=tf.float32))
-    self._intermediate_dense = dense_einsum.DenseEinsum(
-        output_shape=self._intermediate_size,
-        activation=None,
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="intermediate")
+    self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
+        "abc,cd->abd",
+        output_shape=(None, self._intermediate_size),
+        bias_axes="d",
+        name="intermediate",
+        **common_kwargs)
     policy = tf.keras.mixed_precision.experimental.global_policy()
     if policy.name == "mixed_bfloat16":
       # bfloat16 causes BERT with the LAMB optimizer to not converge
@@ -151,16 +146,12 @@ class Transformer(tf.keras.layers.Layer):
       policy = tf.float32
     self._intermediate_activation_layer = tf.keras.layers.Activation(
         self._intermediate_activation, dtype=policy)
-    self._output_dense = dense_einsum.DenseEinsum(
-        output_shape=hidden_size,
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="output")
+    self._output_dense = tf.keras.layers.experimental.EinsumDense(
+        "abc,cd->abd",
+        output_shape=(None, hidden_size),
+        bias_axes="d",
+        name="output",
+        **common_kwargs)
     self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
     # Use float32 in layernorm for numeric stability.
     self._output_layer_norm = tf.keras.layers.LayerNormalization(
@@ -312,30 +303,27 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
           "The hidden size (%d) is not a multiple of the number of attention "
           "heads (%d)" % (hidden_size, self.num_attention_heads))
     self.attention_head_size = int(hidden_size / self.num_attention_heads)
-    # Self attention.
-    self.self_attention = attention.CachedAttention(
-        num_heads=self.num_attention_heads,
-        key_size=self.attention_head_size,
-        dropout=self.attention_dropout_rate,
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="self_attention")
-    self.self_attention_output_dense = dense_einsum.DenseEinsum(
-        output_shape=hidden_size,
-        num_summed_dimensions=2,
+    common_kwargs = dict(
         kernel_initializer=self._kernel_initializer,
         bias_initializer=self._bias_initializer,
         kernel_regularizer=self._kernel_regularizer,
         bias_regularizer=self._bias_regularizer,
         activity_regularizer=self._activity_regularizer,
         kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="self_attention_output")
+        bias_constraint=self._bias_constraint)
+    # Self attention.
+    self.self_attention = attention.CachedAttention(
+        num_heads=self.num_attention_heads,
+        key_size=self.attention_head_size,
+        dropout=self.attention_dropout_rate,
+        name="self_attention",
+        **common_kwargs)
+    self.self_attention_output_dense = tf.keras.layers.experimental.EinsumDense(
+        "abc,cd->abd",
+        output_shape=(None, hidden_size),
+        bias_axes="d",
+        name="output",
+        **common_kwargs)
     self.self_attention_dropout = tf.keras.layers.Dropout(
         rate=self.dropout_rate)
     self.self_attention_layer_norm = (
@@ -347,14 +335,8 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
         key_size=self.attention_head_size,
         dropout=self.attention_dropout_rate,
         output_shape=hidden_size,
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="attention/encdec")
+        name="attention/encdec",
+        **common_kwargs)
     self.encdec_attention_dropout = tf.keras.layers.Dropout(
         rate=self.dropout_rate)
@@ -363,29 +345,20 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
             name="attention/encdec_output_layer_norm", axis=-1, epsilon=1e-12))
     # Feed-forward projection.
-    self.intermediate_dense = dense_einsum.DenseEinsum(
-        output_shape=self.intermediate_size,
-        activation=None,
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="intermediate")
+    self.intermediate_dense = tf.keras.layers.experimental.EinsumDense(
+        "abc,cd->abd",
+        output_shape=(None, self.intermediate_size),
+        bias_axes="d",
+        name="intermediate",
+        **common_kwargs)
     self.intermediate_activation_layer = tf.keras.layers.Activation(
         self.intermediate_activation)
-    self.output_dense = dense_einsum.DenseEinsum(
-        output_shape=hidden_size,
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint,
-        name="output")
+    self.output_dense = tf.keras.layers.experimental.EinsumDense(
+        "abc,cd->abd",
+        output_shape=(None, hidden_size),
+        bias_axes="d",
+        name="output",
+        **common_kwargs)
     self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
     self.output_layer_norm = tf.keras.layers.LayerNormalization(
         name="output_layer_norm", axis=-1, epsilon=1e-12)
...
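The other recurring change is the `common_kwargs = dict(...)` / `**common_kwargs` refactor: the initializer, regularizer, and constraint arguments shared by every sub-layer are collected once per `build()` and splatted into each constructor. A condensed, hypothetical illustration of the pattern (not code from this commit):

```python
import tensorflow as tf


def build_projections(num_heads, head_size, hidden_size):
  """Hypothetical helper mirroring the refactored build() methods."""
  common_kwargs = dict(
      kernel_initializer="glorot_uniform",
      bias_initializer="zeros",
      kernel_regularizer=tf.keras.regularizers.l2(1e-4))
  # Per-head query projection: (batch, seq, embed) -> (batch, seq, heads, head_size).
  query = tf.keras.layers.experimental.EinsumDense(
      "BAE,ENH->BANH",
      output_shape=(None, num_heads, head_size),
      bias_axes="NH",
      name="query",
      **common_kwargs)
  # Last-axis output projection: (batch, seq, d) -> (batch, seq, hidden_size).
  output = tf.keras.layers.experimental.EinsumDense(
      "abc,cd->abd",
      output_shape=(None, hidden_size),
      bias_axes="d",
      name="output",
      **common_kwargs)
  return query, output
```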