Commit a506d3a4 authored by Hongkun Yu, committed by A. Unique TensorFlower

Migrate to tf.keras.layers.EinsumDense after TF 2.9

PiperOrigin-RevId: 453749903
parent 1738af44
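The change is mechanical: every reference to tf.keras.layers.experimental.EinsumDense becomes tf.keras.layers.EinsumDense, the stable export of the layer as of TF 2.9; constructor arguments are unchanged. As a minimal sketch (not part of the commit; assumes TF 2.9 or newer is installed), the renamed layer is used exactly the way the experimental alias was:

    import tensorflow as tf

    # Stable path after the migration; the experimental alias accepted the same
    # arguments, so call sites only need the rename.
    dense = tf.keras.layers.EinsumDense(
        "abc,cd->abd",              # contract the last axis of a [B, T, C] input
        output_shape=(None, 128),   # per-example output shape: [T, 128]
        bias_axes="d",              # one bias term per output feature
        name="intermediate")

    x = tf.random.normal([2, 16, 64])  # [batch, seq_len, hidden]
    y = dense(x)                        # -> shape [2, 16, 128]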
@@ -52,19 +52,19 @@ class Attention(tf.keras.layers.Layer):
     attention_initializer = _glorot_initializer(input_shape.as_list()[-1],
                                                 self.hidden_size)
-    self.query_dense_layer = tf.keras.layers.experimental.EinsumDense(
+    self.query_dense_layer = tf.keras.layers.EinsumDense(
         "BTE,ENH->BTNH",
         output_shape=(None, self.num_heads, size_per_head),
         kernel_initializer=tf_utils.clone_initializer(attention_initializer),
         bias_axes=None,
         name="query")
-    self.key_dense_layer = tf.keras.layers.experimental.EinsumDense(
+    self.key_dense_layer = tf.keras.layers.EinsumDense(
         "BTE,ENH->BTNH",
         output_shape=(None, self.num_heads, size_per_head),
         kernel_initializer=tf_utils.clone_initializer(attention_initializer),
         bias_axes=None,
         name="key")
-    self.value_dense_layer = tf.keras.layers.experimental.EinsumDense(
+    self.value_dense_layer = tf.keras.layers.EinsumDense(
         "BTE,ENH->BTNH",
         output_shape=(None, self.num_heads, size_per_head),
         kernel_initializer=tf_utils.clone_initializer(attention_initializer),
@@ -72,7 +72,7 @@ class Attention(tf.keras.layers.Layer):
         name="value")
     output_initializer = _glorot_initializer(self.hidden_size, self.hidden_size)
-    self.output_dense_layer = tf.keras.layers.experimental.EinsumDense(
+    self.output_dense_layer = tf.keras.layers.EinsumDense(
         "BTNH,NHE->BTE",
         output_shape=(None, self.hidden_size),
         kernel_initializer=output_initializer,
...
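For context, a sketch (illustrative sizes, not from the commit) of what the "BTE,ENH->BTNH" projection above computes: it maps a [batch, time, hidden] activation directly to per-head slices, so no separate reshape into heads is needed.

    import tensorflow as tf

    num_heads, size_per_head, hidden = 8, 64, 512  # hypothetical sizes
    query_proj = tf.keras.layers.EinsumDense(
        "BTE,ENH->BTNH",
        output_shape=(None, num_heads, size_per_head),  # [T, N, H]; T inferred
        bias_axes=None,
        name="query")

    x = tf.random.normal([4, 32, hidden])  # [B, T, E]
    q = query_proj(x)                      # -> shape [4, 32, 8, 64], i.e. [B, T, N, H]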
@@ -18,7 +18,7 @@ import math
 import tensorflow as tf
-EinsumDense = tf.keras.layers.experimental.EinsumDense
+EinsumDense = tf.keras.layers.EinsumDense
 MultiHeadAttention = tf.keras.layers.MultiHeadAttention
...
@@ -88,7 +88,7 @@ class BlockDiagFeedforward(tf.keras.layers.Layer):
         kernel_constraint=self._kernel_constraint,
         bias_constraint=self._bias_constraint)
-    self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
+    self._intermediate_dense = tf.keras.layers.EinsumDense(
         "abc,cde->abde",
         output_shape=(None, self._num_blocks,
                       self._intermediate_size // self._num_blocks),
@@ -106,10 +106,9 @@ class BlockDiagFeedforward(tf.keras.layers.Layer):
     self._intermediate_activation_layer = tf.keras.layers.Activation(
         self._intermediate_activation, dtype=policy)
-    self._output_dense = tf.keras.layers.experimental.EinsumDense(
+    self._output_dense = tf.keras.layers.EinsumDense(
         "abde,deo->abdo",
-        output_shape=(None, self._num_blocks,
-                      hidden_size // self._num_blocks),
+        output_shape=(None, self._num_blocks, hidden_size // self._num_blocks),
         bias_axes="do",
         name="output",
         kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
@@ -117,7 +116,7 @@ class BlockDiagFeedforward(tf.keras.layers.Layer):
         **common_kwargs)
     if self._apply_mixing:
-      self._output_mixing = tf.keras.layers.experimental.EinsumDense(
+      self._output_mixing = tf.keras.layers.EinsumDense(
          "abdo,de->abeo",
          output_shape=(None, self._num_blocks,
                        hidden_size // self._num_blocks),
...
@@ -116,7 +116,7 @@ class GatedFeedforward(tf.keras.layers.Layer):
       activation_policy = tf.float32
     for i in range(self._num_blocks):
       self._intermediate_dense.append(
-          tf.keras.layers.experimental.EinsumDense(
+          tf.keras.layers.EinsumDense(
              "abc,cd->abd",
              output_shape=(None, self._intermediate_size),
              bias_axes="d",
@@ -131,7 +131,7 @@ class GatedFeedforward(tf.keras.layers.Layer):
              self._intermediate_activation, dtype=activation_policy))
       if self._use_gate:
         self._gate_dense.append(
-            tf.keras.layers.experimental.EinsumDense(
+            tf.keras.layers.EinsumDense(
                "abc,cd->abd",
                output_shape=(None, self._intermediate_size),
                bias_axes="d",
@@ -142,7 +142,7 @@ class GatedFeedforward(tf.keras.layers.Layer):
                    self._bias_initializer),
                **common_kwargs))
       self._output_dense.append(
-          tf.keras.layers.experimental.EinsumDense(
+          tf.keras.layers.EinsumDense(
              "abc,cd->abd",
              output_shape=(None, hidden_size),
              bias_axes="d",
...
@@ -122,7 +122,7 @@ class MobileBertEmbedding(tf.keras.layers.Layer):
         max_length=max_sequence_length,
         initializer=tf_utils.clone_initializer(self.initializer),
         name='position_embedding')
-    self.word_embedding_proj = tf.keras.layers.experimental.EinsumDense(
+    self.word_embedding_proj = tf.keras.layers.EinsumDense(
         'abc,cd->abd',
         output_shape=[None, self.output_embed_size],
         kernel_initializer=tf_utils.clone_initializer(self.initializer),
@@ -244,7 +244,7 @@ class MobileBertTransformer(tf.keras.layers.Layer):
     self.block_layers = {}
     # add input bottleneck
-    dense_layer_2d = tf.keras.layers.experimental.EinsumDense(
+    dense_layer_2d = tf.keras.layers.EinsumDense(
         'abc,cd->abd',
         output_shape=[None, self.intra_bottleneck_size],
         bias_axes='d',
@@ -256,7 +256,7 @@ class MobileBertTransformer(tf.keras.layers.Layer):
                                           layer_norm]
     if self.key_query_shared_bottleneck:
-      dense_layer_2d = tf.keras.layers.experimental.EinsumDense(
+      dense_layer_2d = tf.keras.layers.EinsumDense(
          'abc,cd->abd',
          output_shape=[None, self.intra_bottleneck_size],
          bias_axes='d',
@@ -286,7 +286,7 @@ class MobileBertTransformer(tf.keras.layers.Layer):
     for ffn_layer_idx in range(self.num_feedforward_networks):
       layer_prefix = f'ffn_layer_{ffn_layer_idx}'
       layer_name = layer_prefix + '/intermediate_dense'
-      intermediate_layer = tf.keras.layers.experimental.EinsumDense(
+      intermediate_layer = tf.keras.layers.EinsumDense(
          'abc,cd->abd',
          activation=self.intermediate_act_fn,
          output_shape=[None, self.intermediate_size],
@@ -294,7 +294,7 @@ class MobileBertTransformer(tf.keras.layers.Layer):
          kernel_initializer=tf_utils.clone_initializer(self.initializer),
          name=layer_name)
       layer_name = layer_prefix + '/output_dense'
-      output_layer = tf.keras.layers.experimental.EinsumDense(
+      output_layer = tf.keras.layers.EinsumDense(
          'abc,cd->abd',
          output_shape=[None, self.intra_bottleneck_size],
          bias_axes='d',
@@ -308,7 +308,7 @@ class MobileBertTransformer(tf.keras.layers.Layer):
                                           layer_norm])
     # add output bottleneck
-    bottleneck = tf.keras.layers.experimental.EinsumDense(
+    bottleneck = tf.keras.layers.EinsumDense(
         'abc,cd->abd',
         output_shape=[None, self.hidden_size],
         activation=None,
...
@@ -66,7 +66,7 @@ class VotingAttention(tf.keras.layers.Layer):
         activity_regularizer=self._activity_regularizer,
         kernel_constraint=self._kernel_constraint,
         bias_constraint=self._bias_constraint)
-    self._query_dense = tf.keras.layers.experimental.EinsumDense(
+    self._query_dense = tf.keras.layers.EinsumDense(
         "BAE,ENH->BANH",
         output_shape=(None, self._num_heads, self._head_size),
         bias_axes="NH",
@@ -74,7 +74,7 @@ class VotingAttention(tf.keras.layers.Layer):
         kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
         bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
         **common_kwargs)
-    self._key_dense = tf.keras.layers.experimental.EinsumDense(
+    self._key_dense = tf.keras.layers.EinsumDense(
         "BAE,ENH->BANH",
         output_shape=(None, self._num_heads, self._head_size),
         bias_axes="NH",
...
@@ -98,14 +98,14 @@ class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention):
         `[B, L, dim]`.
       segment_matrix: Optional `Tensor` representing segmentation IDs used in
         XLNet of shape `[B, S, S + M]`.
-      segment_encoding: Optional `Tensor` representing the segmentation
-        encoding as used in XLNet of shape `[2, num_heads, dim]`.
-      segment_attention_bias: Optional trainable bias parameter added to the
-        query had when calculating the segment-based attention score used in
-        XLNet of shape `[num_heads, dim]`.
+      segment_encoding: Optional `Tensor` representing the segmentation encoding
+        as used in XLNet of shape `[2, num_heads, dim]`.
+      segment_attention_bias: Optional trainable bias parameter added to the query
+        had when calculating the segment-based attention score used in XLNet of
+        shape `[num_heads, dim]`.
       state: Optional `Tensor` of shape `[B, M, E]` where M is the length of the
-        state or memory.
-        If passed, this is also attended over as in Transformer XL.
+        state or memory. If passed, this is also attended over as in Transformer
+        XL.
       attention_mask: A boolean mask of shape `[B, T, S]` that prevents attention
         to certain positions.
     """
@@ -144,7 +144,7 @@ class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention):
     with tf.init_scope():
       einsum_equation, _, output_rank = _build_proj_equation(
          key_shape.rank - 1, bound_dims=1, output_dims=2)
-      self._encoding_dense = tf.keras.layers.experimental.EinsumDense(
+      self._encoding_dense = tf.keras.layers.EinsumDense(
          einsum_equation,
          output_shape=_get_output_shape(output_rank - 1,
                                         [self._num_heads, self._key_dim]),
@@ -255,8 +255,8 @@ class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention):
     Args:
       query: attention input.
       value: attention input.
-      content_attention_bias: A trainable bias parameter added to the query
-        head when calculating the content-based attention score.
+      content_attention_bias: A trainable bias parameter added to the query head
+        when calculating the content-based attention score.
       positional_attention_bias: A trainable bias parameter added to the query
         head when calculating the position-based attention score.
       key: attention input.
@@ -264,8 +264,8 @@ class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention):
        value.
       segment_matrix: Optional `Tensor` representing segmentation IDs used in
        XLNet.
-      segment_encoding: Optional `Tensor` representing the segmentation
-        encoding as used in XLNet.
+      segment_encoding: Optional `Tensor` representing the segmentation encoding
+        as used in XLNet.
       segment_attention_bias: Optional trainable bias parameter added to the
        query had when calculating the segment-based attention score used in
        XLNet.
@@ -394,22 +394,22 @@ class TwoStreamRelativeAttention(MultiHeadRelativeAttention):
       content_stream: The content representation, commonly referred to as h.
        This serves a similar role to the standard hidden states in
        Transformer-XL.
-      content_attention_bias: A trainable bias parameter added to the query
-        head when calculating the content-based attention score.
+      content_attention_bias: A trainable bias parameter added to the query head
+        when calculating the content-based attention score.
       positional_attention_bias: A trainable bias parameter added to the query
        head when calculating the position-based attention score.
-      query_stream: The query representation, commonly referred to as g.
-        This only has access to contextual information and position, but not
-        content. If not provided, then this is MultiHeadRelativeAttention with
+      query_stream: The query representation, commonly referred to as g. This
+        only has access to contextual information and position, but not content.
+        If not provided, then this is MultiHeadRelativeAttention with
        self-attention.
       relative_position_encoding: relative positional encoding for key and
        value.
-      target_mapping: Optional `Tensor` representing the target mapping used
-        in partial prediction.
+      target_mapping: Optional `Tensor` representing the target mapping used in
+        partial prediction.
       segment_matrix: Optional `Tensor` representing segmentation IDs used in
        XLNet.
-      segment_encoding: Optional `Tensor` representing the segmentation
-        encoding as used in XLNet.
+      segment_encoding: Optional `Tensor` representing the segmentation encoding
+        as used in XLNet.
       segment_attention_bias: Optional trainable bias parameter added to the
        query head when calculating the segment-based attention score.
       state: (default None) optional state. If passed, this is also attended
@@ -417,8 +417,8 @@ class TwoStreamRelativeAttention(MultiHeadRelativeAttention):
       content_attention_mask: (default None) Optional mask that is added to
        content attention logits. If state is not None, the mask source sequence
        dimension should extend M.
-      query_attention_mask: (default None) Optional mask that is added to
-        query attention logits. If state is not None, the mask source sequence
+      query_attention_mask: (default None) Optional mask that is added to query
+        attention logits. If state is not None, the mask source sequence
        dimension should extend M.
     Returns:
@@ -496,4 +496,3 @@ class TwoStreamRelativeAttention(MultiHeadRelativeAttention):
     query_attention_output = self._output_dense(query_attention_output)
     return content_attention_output, query_attention_output
@@ -362,58 +362,61 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer):
     if self._reuse_heads < self._num_heads:
       einsum_equation, bias_axes, output_rank = _build_proj_equation(
           free_dims, bound_dims=1, output_dims=2)
-      self._query_dense = tf.keras.layers.experimental.EinsumDense(
+      self._query_dense = tf.keras.layers.EinsumDense(
           einsum_equation,
-          output_shape=_get_output_shape(output_rank - 1, [
-              self._num_heads - self._reuse_heads, self._key_dim]),
+          output_shape=_get_output_shape(
+              output_rank - 1,
+              [self._num_heads - self._reuse_heads, self._key_dim]),
           bias_axes=bias_axes if self._use_bias else None,
           name="query",
           kernel_initializer=tf_utils.clone_initializer(
               self._kernel_initializer),
-          bias_initializer=tf_utils.clone_initializer(
-              self._bias_initializer),
+          bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
          **common_kwargs)
       einsum_equation, bias_axes, output_rank = _build_proj_equation(
           self._key_shape.rank - 1, bound_dims=1, output_dims=2)
-      self._key_dense = tf.keras.layers.experimental.EinsumDense(
+      self._key_dense = tf.keras.layers.EinsumDense(
           einsum_equation,
-          output_shape=_get_output_shape(output_rank - 1, [
-              self._num_heads - self._reuse_heads, self._key_dim]),
+          output_shape=_get_output_shape(
+              output_rank - 1,
+              [self._num_heads - self._reuse_heads, self._key_dim]),
           bias_axes=bias_axes if self._use_bias else None,
           name="key",
           kernel_initializer=tf_utils.clone_initializer(
               self._kernel_initializer),
-          bias_initializer=tf_utils.clone_initializer(
-              self._bias_initializer),
+          bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
          **common_kwargs)
     einsum_equation, bias_axes, output_rank = _build_proj_equation(
         self._value_shape.rank - 1, bound_dims=1, output_dims=2)
     self._value_dense = []
     if self._reuse_heads > 0:
-      self._value_dense.append(tf.keras.layers.experimental.EinsumDense(
-          einsum_equation,
-          output_shape=_get_output_shape(
-              output_rank - 1, [self._reuse_heads, self._value_dim]),
-          bias_axes=bias_axes if self._use_bias else None,
-          name="value_reuse",
-          kernel_initializer=tf_utils.clone_initializer(
-              self._kernel_initializer),
-          bias_initializer=tf_utils.clone_initializer(
-              self._bias_initializer),
-          **common_kwargs))
+      self._value_dense.append(
+          tf.keras.layers.EinsumDense(
+              einsum_equation,
+              output_shape=_get_output_shape(
+                  output_rank - 1, [self._reuse_heads, self._value_dim]),
+              bias_axes=bias_axes if self._use_bias else None,
+              name="value_reuse",
+              kernel_initializer=tf_utils.clone_initializer(
+                  self._kernel_initializer),
+              bias_initializer=tf_utils.clone_initializer(
+                  self._bias_initializer),
+              **common_kwargs))
     if self._reuse_heads < self._num_heads:
-      self._value_dense.append(tf.keras.layers.experimental.EinsumDense(
-          einsum_equation,
-          output_shape=_get_output_shape(output_rank - 1, [
-              self._num_heads - self._reuse_heads, self._value_dim]),
-          bias_axes=bias_axes if self._use_bias else None,
-          name="value_new",
-          kernel_initializer=tf_utils.clone_initializer(
-              self._kernel_initializer),
-          bias_initializer=tf_utils.clone_initializer(
-              self._bias_initializer),
-          **common_kwargs))
+      self._value_dense.append(
+          tf.keras.layers.EinsumDense(
+              einsum_equation,
+              output_shape=_get_output_shape(
+                  output_rank - 1,
+                  [self._num_heads - self._reuse_heads, self._value_dim]),
+              bias_axes=bias_axes if self._use_bias else None,
+              name="value_new",
+              kernel_initializer=tf_utils.clone_initializer(
+                  self._kernel_initializer),
+              bias_initializer=tf_utils.clone_initializer(
+                  self._bias_initializer),
+              **common_kwargs))
     # Builds the attention computations for multi-head dot product attention.
     # These computations could be wrapped into the keras attention layer once
@@ -450,15 +453,13 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer):
     output_shape = [self._query_shape[-1]]
     einsum_equation, bias_axes, output_rank = _build_proj_equation(
         free_dims, bound_dims=2, output_dims=len(output_shape))
-    return tf.keras.layers.experimental.EinsumDense(
+    return tf.keras.layers.EinsumDense(
        einsum_equation,
        output_shape=_get_output_shape(output_rank - 1, output_shape),
        bias_axes=bias_axes if (use_bias and self._use_bias) else None,
        name=name,
-        kernel_initializer=tf_utils.clone_initializer(
-            self._kernel_initializer),
-        bias_initializer=tf_utils.clone_initializer(
-            self._bias_initializer),
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
+        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
        **common_kwargs)
   def _build_attention(self, rank):
...
@@ -187,7 +187,7 @@ class ReuseTransformer(tf.keras.layers.Layer):
         axis=-1,
         epsilon=self._norm_epsilon,
         dtype=tf.float32))
-    self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
+    self._intermediate_dense = tf.keras.layers.EinsumDense(
         einsum_equation,
         output_shape=(None, self._inner_dim),
         bias_axes="d",
@@ -205,7 +205,7 @@ class ReuseTransformer(tf.keras.layers.Layer):
         self._inner_activation, dtype=policy)
     self._inner_dropout_layer = tf.keras.layers.Dropout(
         rate=self._inner_dropout)
-    self._output_dense = tf.keras.layers.experimental.EinsumDense(
+    self._output_dense = tf.keras.layers.EinsumDense(
         einsum_equation,
         output_shape=(None, hidden_size),
         bias_axes="d",
...
@@ -145,7 +145,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
         axis=-1,
         epsilon=1e-12,
         dtype=tf.float32))
-    self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
+    self._intermediate_dense = tf.keras.layers.EinsumDense(
         "abc,cd->abd",
         output_shape=(None, self._intermediate_size),
         bias_axes="d",
@@ -161,7 +161,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
       policy = tf.float32
     self._intermediate_activation_layer = tf.keras.layers.Activation(
         self._intermediate_activation, dtype=policy)
-    self._output_dense = tf.keras.layers.experimental.EinsumDense(
+    self._output_dense = tf.keras.layers.EinsumDense(
         "abc,cd->abd",
         output_shape=(None, hidden_size),
         bias_axes="d",
...
@@ -262,7 +262,7 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
         bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
         name="self_attention",
         **common_kwargs)
-    self.self_attention_output_dense = tf.keras.layers.experimental.EinsumDense(
+    self.self_attention_output_dense = tf.keras.layers.EinsumDense(
         "abc,cd->abd",
         output_shape=(None, hidden_size),
         bias_axes="d",
@@ -301,7 +301,7 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
             dtype="float32"))
     # Feed-forward projection.
-    self.intermediate_dense = tf.keras.layers.experimental.EinsumDense(
+    self.intermediate_dense = tf.keras.layers.EinsumDense(
         "abc,cd->abd",
         output_shape=(None, self.intermediate_size),
         bias_axes="d",
@@ -313,7 +313,7 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
         self.intermediate_activation)
     self._intermediate_dropout_layer = tf.keras.layers.Dropout(
         rate=self._intermediate_dropout)
-    self.output_dense = tf.keras.layers.experimental.EinsumDense(
+    self.output_dense = tf.keras.layers.EinsumDense(
         "abc,cd->abd",
         output_shape=(None, hidden_size),
         bias_axes="d",
...
@@ -224,7 +224,7 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
         epsilon=self._norm_epsilon,
         dtype=tf.float32))
-    self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
+    self._intermediate_dense = tf.keras.layers.EinsumDense(
         einsum_equation,
         output_shape=(None, self._inner_dim),
         bias_axes="d",
@@ -242,7 +242,7 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
         self._inner_activation, dtype=policy)
     self._inner_dropout_layer = tf.keras.layers.Dropout(
         rate=self._inner_dropout)
-    self._output_dense = tf.keras.layers.experimental.EinsumDense(
+    self._output_dense = tf.keras.layers.EinsumDense(
         einsum_equation,
         output_shape=(None, last_output_shape),
         bias_axes="d",
...
@@ -190,7 +190,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
             dtype=tf.float32))
     if self._feedforward_block is None:
-      self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
+      self._intermediate_dense = tf.keras.layers.EinsumDense(
          "abc,cd->abd",
          output_shape=(None, self._intermediate_size),
          bias_axes="d",
@@ -207,7 +207,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
        policy = tf.float32
       self._intermediate_activation_layer = tf.keras.layers.Activation(
          self._intermediate_activation, dtype=policy)
-      self._output_dense = tf.keras.layers.experimental.EinsumDense(
+      self._output_dense = tf.keras.layers.EinsumDense(
          "abc,cd->abd",
          output_shape=(None, hidden_size),
          bias_axes="d",
...
@@ -58,7 +58,7 @@ class ValidatedFeedforwardLayer(tf.keras.layers.Layer):
   def build(self, input_shape):
     hidden_size = input_shape.as_list()[-1]
-    self._feedforward_dense = tf.keras.layers.experimental.EinsumDense(
+    self._feedforward_dense = tf.keras.layers.EinsumDense(
         '...x,xy->...y',
         output_shape=hidden_size,
         bias_axes='y',
...
@@ -158,7 +158,7 @@ class TransformerXLBlock(tf.keras.layers.Layer):
         axis=-1,
         epsilon=self._norm_epsilon,
         dtype=tf.float32)
-    self._inner_dense = tf.keras.layers.experimental.EinsumDense(
+    self._inner_dense = tf.keras.layers.EinsumDense(
         "abc,cd->abd",
         output_shape=(None, self._inner_size),
         bias_axes="d",
@@ -169,7 +169,7 @@ class TransformerXLBlock(tf.keras.layers.Layer):
         self._inner_activation)
     self._inner_dropout_layer = tf.keras.layers.Dropout(
         rate=self._inner_dropout)
-    self._output_dense = tf.keras.layers.experimental.EinsumDense(
+    self._output_dense = tf.keras.layers.EinsumDense(
         "abc,cd->abd",
         output_shape=(None, hidden_size),
         bias_axes="d",
...
@@ -124,7 +124,7 @@ class AlbertEncoder(tf.keras.Model):
     # We project the 'embedding' output to 'hidden_size' if it is not already
     # 'hidden_size'.
     if embedding_width != hidden_size:
-      embeddings = tf.keras.layers.experimental.EinsumDense(
+      embeddings = tf.keras.layers.EinsumDense(
          '...x,xy->...y',
          output_shape=hidden_size,
          bias_axes='y',
...
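As an aside, a sketch (hypothetical sizes, not from the commit) of the '...x,xy->...y' projection used by these encoders: the ellipsis keeps whatever leading dimensions the input has, so one layer projects [batch, seq, embedding_width] activations up to hidden_size.

    import tensorflow as tf

    embedding_width, hidden_size = 128, 768  # hypothetical sizes
    proj = tf.keras.layers.EinsumDense(
        '...x,xy->...y',
        output_shape=hidden_size,  # only the last axis is specified
        bias_axes='y',
        name='embedding_projection')

    emb = tf.random.normal([2, 16, embedding_width])
    print(proj(emb).shape)  # (2, 16, 768)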
@@ -150,7 +150,7 @@ class BertEncoderV2(tf.keras.layers.Layer):
     # 'hidden_size'.
     self._embedding_projection = None
     if embedding_width != hidden_size:
-      self._embedding_projection = tf.keras.layers.experimental.EinsumDense(
+      self._embedding_projection = tf.keras.layers.EinsumDense(
          '...x,xy->...y',
          output_shape=hidden_size,
          bias_axes='y',
@@ -442,7 +442,7 @@ class BertEncoder(tf.keras.Model):
     # We project the 'embedding' output to 'hidden_size' if it is not already
     # 'hidden_size'.
     if embedding_width != hidden_size:
-      embedding_projection = tf.keras.layers.experimental.EinsumDense(
+      embedding_projection = tf.keras.layers.EinsumDense(
          '...x,xy->...y',
          output_shape=hidden_size,
          bias_axes='y',
...
@@ -293,7 +293,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
     # 'hidden_size'.
     self._embedding_projection = None
     if embedding_width != hidden_size:
-      self._embedding_projection = tf.keras.layers.experimental.EinsumDense(
+      self._embedding_projection = tf.keras.layers.EinsumDense(
          '...x,xy->...y',
          output_shape=hidden_size,
          bias_axes='y',
...
@@ -146,7 +146,7 @@ class MobileBERTEncoder(tf.keras.Model):
     first_token = tf.squeeze(prev_output[:, 0:1, :], axis=1)
     if classifier_activation:
-      self._pooler_layer = tf.keras.layers.experimental.EinsumDense(
+      self._pooler_layer = tf.keras.layers.EinsumDense(
          'ab,bc->ac',
          output_shape=hidden_size,
          activation=tf.tanh,
...
@@ -128,7 +128,7 @@ class PackedSequenceEmbedding(tf.keras.Model):
         embeddings)
     if embedding_width != hidden_size:
-      embeddings = tf.keras.layers.experimental.EinsumDense(
+      embeddings = tf.keras.layers.EinsumDense(
          '...x,xy->...y',
          output_shape=hidden_size,
          bias_axes=None,
...