Commit 7c3fd62e authored by Hongkun Yu, committed by A. Unique TensorFlower

Migrate to tf.keras.layers.EinsumDense after TF 2.9

PiperOrigin-RevId: 453749903
parent 552ca56f
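The change is mechanical: in TF 2.9+ the layer has graduated out of the experimental namespace, and only the attribute path changes while the constructor arguments stay the same. A minimal sketch of the new call site, not taken from this commit (the hidden_size value and input shapes below are hypothetical, for illustration only):

import tensorflow as tf

# Minimal sketch, assuming TF >= 2.9: tf.keras.layers.EinsumDense replaces
# tf.keras.layers.experimental.EinsumDense with an identical constructor.
hidden_size = 64  # hypothetical value
projection = tf.keras.layers.EinsumDense(
    '...x,xy->...y',           # project the last axis to hidden_size
    output_shape=hidden_size,
    bias_axes='y')

x = tf.random.normal([2, 8, 32])   # (batch, seq_len, embedding_width)
y = projection(x)                  # -> shape (2, 8, 64)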
@@ -173,7 +173,7 @@ class BigBirdEncoder(tf.keras.Model):
     # We project the 'embedding' output to 'hidden_size' if it is not already
     # 'hidden_size'.
     if embedding_width != hidden_size:
-      self._embedding_projection = tf.keras.layers.experimental.EinsumDense(
+      self._embedding_projection = tf.keras.layers.EinsumDense(
           '...x,xy->...y',
           output_shape=hidden_size,
           bias_axes='y',
...
@@ -283,7 +283,7 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
             axis=-1,
             epsilon=self._norm_epsilon,
             dtype=tf.float32))
-    self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
+    self._intermediate_dense = tf.keras.layers.EinsumDense(
         einsum_equation,
         output_shape=(None, self._inner_dim),
         bias_axes="d",
@@ -300,7 +300,7 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
         self._inner_activation, dtype=policy)
     self._inner_dropout_layer = tf.keras.layers.Dropout(
         rate=self._inner_dropout)
-    self._output_dense = tf.keras.layers.experimental.EinsumDense(
+    self._output_dense = tf.keras.layers.EinsumDense(
         einsum_equation,
         output_shape=(None, hidden_size),
         bias_axes="d",
@@ -691,7 +691,7 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
         kernel_initializer=self._attention_initializer,
         name="self_attention",
         **common_kwargs)
-    self.self_attention_output_dense = tf.keras.layers.experimental.EinsumDense(
+    self.self_attention_output_dense = tf.keras.layers.EinsumDense(
         "abc,cd->abd",
         output_shape=(None, hidden_size),
         bias_axes="d",
@@ -727,7 +727,7 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
             dtype="float32"))
     # Feed-forward projection.
-    self.intermediate_dense = tf.keras.layers.experimental.EinsumDense(
+    self.intermediate_dense = tf.keras.layers.EinsumDense(
         "abc,cd->abd",
         output_shape=(None, self.intermediate_size),
         bias_axes="d",
@@ -738,7 +738,7 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
         self.intermediate_activation)
     self._intermediate_dropout_layer = tf.keras.layers.Dropout(
         rate=self._intermediate_dropout)
-    self.output_dense = tf.keras.layers.experimental.EinsumDense(
+    self.output_dense = tf.keras.layers.EinsumDense(
         "abc,cd->abd",
         output_shape=(None, hidden_size),
         bias_axes="d",
...
@@ -161,7 +161,7 @@ class MobileBERTEncoder(tf.keras.Model):
     first_token = tf.squeeze(prev_output[:, 0:1, :], axis=1)
     if classifier_activation:
-      self._pooler_layer = tf.keras.layers.experimental.EinsumDense(
+      self._pooler_layer = tf.keras.layers.EinsumDense(
           'ab,bc->ac',
           output_shape=hidden_size,
           activation=tf.tanh,
...
@@ -170,14 +170,14 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
     free_dims = self._query_shape.rank - 1
     einsum_equation, bias_axes, output_rank = _build_proj_equation(
         free_dims, bound_dims=1, output_dims=2)
-    self._query_dense = tf.keras.layers.experimental.EinsumDense(
+    self._query_dense = tf.keras.layers.EinsumDense(
         einsum_equation,
         output_shape=_get_output_shape(output_rank - 1,
                                        [self._num_heads, self._key_dim]),
         bias_axes=bias_axes if self._use_bias else None,
         name="query",
         **common_kwargs)
-    self._global_query_dense = tf.keras.layers.experimental.EinsumDense(
+    self._global_query_dense = tf.keras.layers.EinsumDense(
         einsum_equation,
         output_shape=_get_output_shape(output_rank - 1,
                                        [self._num_heads, self._key_dim]),
@@ -186,14 +186,14 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
         **common_kwargs)
     einsum_equation, bias_axes, output_rank = _build_proj_equation(
         self._key_shape.rank - 1, bound_dims=1, output_dims=2)
-    self._key_dense = tf.keras.layers.experimental.EinsumDense(
+    self._key_dense = tf.keras.layers.EinsumDense(
         einsum_equation,
         output_shape=_get_output_shape(output_rank - 1,
                                        [self._num_heads, self._key_dim]),
         bias_axes=bias_axes if self._use_bias else None,
         name="key",
         **common_kwargs)
-    self._global_key_dense = tf.keras.layers.experimental.EinsumDense(
+    self._global_key_dense = tf.keras.layers.EinsumDense(
         einsum_equation,
         output_shape=_get_output_shape(output_rank - 1,
                                        [self._num_heads, self._key_dim]),
@@ -202,14 +202,14 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
         **common_kwargs)
     einsum_equation, bias_axes, output_rank = _build_proj_equation(
         self._value_shape.rank - 1, bound_dims=1, output_dims=2)
-    self._value_dense = tf.keras.layers.experimental.EinsumDense(
+    self._value_dense = tf.keras.layers.EinsumDense(
         einsum_equation,
         output_shape=_get_output_shape(output_rank - 1,
                                        [self._num_heads, self._value_dim]),
         bias_axes=bias_axes if self._use_bias else None,
         name="value",
         **common_kwargs)
-    self._global_value_dense = tf.keras.layers.experimental.EinsumDense(
+    self._global_value_dense = tf.keras.layers.EinsumDense(
         einsum_equation,
         output_shape=_get_output_shape(output_rank - 1,
                                        [self._num_heads, self._value_dim]),
...
@@ -136,7 +136,7 @@ class LongformerEncoder(tf.keras.layers.Layer):
     # 'hidden_size'.
     self._embedding_projection = None
     if embedding_width != hidden_size:
-      self._embedding_projection = tf.keras.layers.experimental.EinsumDense(
+      self._embedding_projection = tf.keras.layers.EinsumDense(
           '...x,xy->...y',
           output_shape=hidden_size,
           bias_axes='y',
...
@@ -166,7 +166,7 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
             dtype=tf.float32))
     # TFLongformerIntermediate
     # TFLongformerIntermediate.dense
-    self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
+    self._intermediate_dense = tf.keras.layers.EinsumDense(
         einsum_equation,
         output_shape=(None, self._inner_dim),
         bias_axes="d",
@@ -186,7 +186,7 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
         rate=self._inner_dropout)
     # TFLongformerOutput
     # TFLongformerOutput.dense
-    self._output_dense = tf.keras.layers.experimental.EinsumDense(
+    self._output_dense = tf.keras.layers.EinsumDense(
         einsum_equation,
         output_shape=(None, hidden_size),
         bias_axes="d",
...
@@ -16,7 +16,7 @@
 # pylint: disable=g-classes-have-attributes
 import tensorflow as tf

-EinsumDense = tf.keras.layers.experimental.EinsumDense
+EinsumDense = tf.keras.layers.EinsumDense
 MultiHeadAttention = tf.keras.layers.MultiHeadAttention
...
@@ -143,7 +143,7 @@ class RoformerEncoder(tf.keras.Model):
     # We project the 'embedding' output to 'hidden_size' if it is not already
     # 'hidden_size'.
     if embedding_width != hidden_size:
-      embedding_projection = tf.keras.layers.experimental.EinsumDense(
+      embedding_projection = tf.keras.layers.EinsumDense(
           '...x,xy->...y',
           output_shape=hidden_size,
           bias_axes='y',
...
@@ -162,7 +162,7 @@ class RoformerEncoderBlock(tf.keras.layers.Layer):
         axis=-1,
         epsilon=self._norm_epsilon,
         dtype=tf.float32))
-    self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
+    self._intermediate_dense = tf.keras.layers.EinsumDense(
         einsum_equation,
         output_shape=(None, self._inner_dim),
         bias_axes="d",
@@ -179,7 +179,7 @@ class RoformerEncoderBlock(tf.keras.layers.Layer):
         self._inner_activation, dtype=policy)
     self._inner_dropout_layer = tf.keras.layers.Dropout(
         rate=self._inner_dropout)
-    self._output_dense = tf.keras.layers.experimental.EinsumDense(
+    self._output_dense = tf.keras.layers.EinsumDense(
         einsum_equation,
         output_shape=(None, hidden_size),
         bias_axes="d",
...
@@ -65,9 +65,9 @@ class TokenDropBertEncoder(tf.keras.layers.Layer):
     token_keep_k: The number of tokens you want to keep in the intermediate
       layers. The rest will be dropped in those layers.
     token_allow_list: The list of token-ids that should not be droped. In the
-      BERT English vocab, token-id from 1 to 998 contains special tokens such
-      as [CLS], [SEP]. By default, token_allow_list contains all of these
-      special tokens.
+      BERT English vocab, token-id from 1 to 998 contains special tokens such as
+      [CLS], [SEP]. By default, token_allow_list contains all of these special
+      tokens.
     token_deny_list: The list of token-ids that should always be droped. In the
       BERT English vocab, token-id=0 means [PAD]. By default, token_deny_list
       contains and only contains [PAD].
@@ -166,7 +166,7 @@ class TokenDropBertEncoder(tf.keras.layers.Layer):
     # 'hidden_size'.
     self._embedding_projection = None
     if embedding_width != hidden_size:
-      self._embedding_projection = tf.keras.layers.experimental.EinsumDense(
+      self._embedding_projection = tf.keras.layers.EinsumDense(
          '...x,xy->...y',
          output_shape=hidden_size,
          bias_axes='y',
@@ -386,4 +386,3 @@ class TokenDropBertEncoder(tf.keras.layers.Layer):
       logging.warn(warn_string)
     return cls(**config)