Commit 7c3fd62e authored by Hongkun Yu, committed by A. Unique TensorFlower

Migrate to tf.keras.layers.EinsumDense after TF 2.9

PiperOrigin-RevId: 453749903
parent 552ca56f
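The change below is mechanical: every call site drops the experimental namespace, since EinsumDense graduated to tf.keras.layers in TF 2.9. As a minimal, illustrative sketch of the pattern (the hidden_size value and the input tensor are made up, not taken from this diff), the embedding-projection call touched in several of these files looks like this before and after:

import tensorflow as tf

hidden_size = 768  # illustrative value only, not taken from this commit

# Before (pre-TF 2.9): the layer lived under the experimental namespace.
# projection = tf.keras.layers.experimental.EinsumDense(
#     '...x,xy->...y', output_shape=hidden_size, bias_axes='y')

# After (TF 2.9+): same layer, same arguments, stable namespace.
projection = tf.keras.layers.EinsumDense(
    '...x,xy->...y',          # contract the last axis: [..., width] -> [..., hidden_size]
    output_shape=hidden_size,
    bias_axes='y')

embeddings = tf.random.normal([2, 4, 128])  # hypothetical [batch, seq, embedding_width]
projected = projection(embeddings)          # -> shape [2, 4, 768]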
@@ -173,7 +173,7 @@ class BigBirdEncoder(tf.keras.Model):
# We project the 'embedding' output to 'hidden_size' if it is not already
# 'hidden_size'.
if embedding_width != hidden_size:
-self._embedding_projection = tf.keras.layers.experimental.EinsumDense(
+self._embedding_projection = tf.keras.layers.EinsumDense(
'...x,xy->...y',
output_shape=hidden_size,
bias_axes='y',
@@ -283,7 +283,7 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
axis=-1,
epsilon=self._norm_epsilon,
dtype=tf.float32))
-self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
+self._intermediate_dense = tf.keras.layers.EinsumDense(
einsum_equation,
output_shape=(None, self._inner_dim),
bias_axes="d",
@@ -300,7 +300,7 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
self._inner_activation, dtype=policy)
self._inner_dropout_layer = tf.keras.layers.Dropout(
rate=self._inner_dropout)
-self._output_dense = tf.keras.layers.experimental.EinsumDense(
+self._output_dense = tf.keras.layers.EinsumDense(
einsum_equation,
output_shape=(None, hidden_size),
bias_axes="d",
@@ -691,7 +691,7 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
kernel_initializer=self._attention_initializer,
name="self_attention",
**common_kwargs)
-self.self_attention_output_dense = tf.keras.layers.experimental.EinsumDense(
+self.self_attention_output_dense = tf.keras.layers.EinsumDense(
"abc,cd->abd",
output_shape=(None, hidden_size),
bias_axes="d",
@@ -727,7 +727,7 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
dtype="float32"))
# Feed-forward projection.
-self.intermediate_dense = tf.keras.layers.experimental.EinsumDense(
+self.intermediate_dense = tf.keras.layers.EinsumDense(
"abc,cd->abd",
output_shape=(None, self.intermediate_size),
bias_axes="d",
@@ -738,7 +738,7 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
self.intermediate_activation)
self._intermediate_dropout_layer = tf.keras.layers.Dropout(
rate=self._intermediate_dropout)
-self.output_dense = tf.keras.layers.experimental.EinsumDense(
+self.output_dense = tf.keras.layers.EinsumDense(
"abc,cd->abd",
output_shape=(None, hidden_size),
bias_axes="d",
@@ -161,7 +161,7 @@ class MobileBERTEncoder(tf.keras.Model):
first_token = tf.squeeze(prev_output[:, 0:1, :], axis=1)
if classifier_activation:
-self._pooler_layer = tf.keras.layers.experimental.EinsumDense(
+self._pooler_layer = tf.keras.layers.EinsumDense(
'ab,bc->ac',
output_shape=hidden_size,
activation=tf.tanh,
@@ -170,14 +170,14 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
free_dims = self._query_shape.rank - 1
einsum_equation, bias_axes, output_rank = _build_proj_equation(
free_dims, bound_dims=1, output_dims=2)
-self._query_dense = tf.keras.layers.experimental.EinsumDense(
+self._query_dense = tf.keras.layers.EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._key_dim]),
bias_axes=bias_axes if self._use_bias else None,
name="query",
**common_kwargs)
-self._global_query_dense = tf.keras.layers.experimental.EinsumDense(
+self._global_query_dense = tf.keras.layers.EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._key_dim]),
@@ -186,14 +186,14 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
**common_kwargs)
einsum_equation, bias_axes, output_rank = _build_proj_equation(
self._key_shape.rank - 1, bound_dims=1, output_dims=2)
-self._key_dense = tf.keras.layers.experimental.EinsumDense(
+self._key_dense = tf.keras.layers.EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._key_dim]),
bias_axes=bias_axes if self._use_bias else None,
name="key",
**common_kwargs)
-self._global_key_dense = tf.keras.layers.experimental.EinsumDense(
+self._global_key_dense = tf.keras.layers.EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._key_dim]),
@@ -202,14 +202,14 @@ class LongformerAttention(tf.keras.layers.MultiHeadAttention):
**common_kwargs)
einsum_equation, bias_axes, output_rank = _build_proj_equation(
self._value_shape.rank - 1, bound_dims=1, output_dims=2)
-self._value_dense = tf.keras.layers.experimental.EinsumDense(
+self._value_dense = tf.keras.layers.EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._value_dim]),
bias_axes=bias_axes if self._use_bias else None,
name="value",
**common_kwargs)
-self._global_value_dense = tf.keras.layers.experimental.EinsumDense(
+self._global_value_dense = tf.keras.layers.EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._value_dim]),
@@ -136,7 +136,7 @@ class LongformerEncoder(tf.keras.layers.Layer):
# 'hidden_size'.
self._embedding_projection = None
if embedding_width != hidden_size:
-self._embedding_projection = tf.keras.layers.experimental.EinsumDense(
+self._embedding_projection = tf.keras.layers.EinsumDense(
'...x,xy->...y',
output_shape=hidden_size,
bias_axes='y',
@@ -166,7 +166,7 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
dtype=tf.float32))
# TFLongformerIntermediate
# TFLongformerIntermediate.dense
-self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
+self._intermediate_dense = tf.keras.layers.EinsumDense(
einsum_equation,
output_shape=(None, self._inner_dim),
bias_axes="d",
@@ -186,7 +186,7 @@ class LongformerEncoderBlock(tf.keras.layers.Layer):
rate=self._inner_dropout)
# TFLongformerOutput
# TFLongformerOutput.dense
-self._output_dense = tf.keras.layers.experimental.EinsumDense(
+self._output_dense = tf.keras.layers.EinsumDense(
einsum_equation,
output_shape=(None, hidden_size),
bias_axes="d",
@@ -16,7 +16,7 @@
# pylint: disable=g-classes-have-attributes
import tensorflow as tf
-EinsumDense = tf.keras.layers.experimental.EinsumDense
+EinsumDense = tf.keras.layers.EinsumDense
MultiHeadAttention = tf.keras.layers.MultiHeadAttention
@@ -143,7 +143,7 @@ class RoformerEncoder(tf.keras.Model):
# We project the 'embedding' output to 'hidden_size' if it is not already
# 'hidden_size'.
if embedding_width != hidden_size:
-embedding_projection = tf.keras.layers.experimental.EinsumDense(
+embedding_projection = tf.keras.layers.EinsumDense(
'...x,xy->...y',
output_shape=hidden_size,
bias_axes='y',
@@ -162,7 +162,7 @@ class RoformerEncoderBlock(tf.keras.layers.Layer):
axis=-1,
epsilon=self._norm_epsilon,
dtype=tf.float32))
-self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
+self._intermediate_dense = tf.keras.layers.EinsumDense(
einsum_equation,
output_shape=(None, self._inner_dim),
bias_axes="d",
@@ -179,7 +179,7 @@ class RoformerEncoderBlock(tf.keras.layers.Layer):
self._inner_activation, dtype=policy)
self._inner_dropout_layer = tf.keras.layers.Dropout(
rate=self._inner_dropout)
-self._output_dense = tf.keras.layers.experimental.EinsumDense(
+self._output_dense = tf.keras.layers.EinsumDense(
einsum_equation,
output_shape=(None, hidden_size),
bias_axes="d",
@@ -65,9 +65,9 @@ class TokenDropBertEncoder(tf.keras.layers.Layer):
token_keep_k: The number of tokens you want to keep in the intermediate
layers. The rest will be dropped in those layers.
token_allow_list: The list of token-ids that should not be droped. In the
-BERT English vocab, token-id from 1 to 998 contains special tokens such
-as [CLS], [SEP]. By default, token_allow_list contains all of these
-special tokens.
+BERT English vocab, token-id from 1 to 998 contains special tokens such as
+[CLS], [SEP]. By default, token_allow_list contains all of these special
+tokens.
token_deny_list: The list of token-ids that should always be droped. In the
BERT English vocab, token-id=0 means [PAD]. By default, token_deny_list
contains and only contains [PAD].
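For illustration only, the allow/deny semantics described in the docstring above can be sketched as a mask over token ids; everything here (the batch of ids, the exact list contents, the precedence of deny over allow) is a hypothetical sketch, not the encoder's actual dropping logic, which is not shown in this commit:

import tensorflow as tf

# Hypothetical batch of token ids; 0 is [PAD] in the BERT English vocab.
token_ids = tf.constant([[101, 2023, 0, 0],
                         [101, 7592, 102, 0]])

token_allow_list = tf.range(1, 999)   # default: protect special tokens 1..998
token_deny_list = tf.constant([0])    # default: always drop [PAD]

protected = tf.reduce_any(
    tf.equal(token_ids[..., tf.newaxis], token_allow_list), axis=-1)
denied = tf.reduce_any(
    tf.equal(token_ids[..., tf.newaxis], token_deny_list), axis=-1)

# Tokens eligible to be dropped in the intermediate layers: anything not on the
# allow list, plus everything on the deny list.
droppable = tf.logical_or(tf.logical_not(protected), denied)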
@@ -166,7 +166,7 @@ class TokenDropBertEncoder(tf.keras.layers.Layer):
# 'hidden_size'.
self._embedding_projection = None
if embedding_width != hidden_size:
-self._embedding_projection = tf.keras.layers.experimental.EinsumDense(
+self._embedding_projection = tf.keras.layers.EinsumDense(
'...x,xy->...y',
output_shape=hidden_size,
bias_axes='y',
@@ -386,4 +386,3 @@ class TokenDropBertEncoder(tf.keras.layers.Layer):
logging.warn(warn_string)
return cls(**config)