"launcher/vscode:/vscode.git/clone" did not exist on "2358c2bb54cf60a1596267192451a46a24f03e06"
Commit a506d3a4 authored by Hongkun Yu, committed by A. Unique TensorFlower

Migrate to tf.keras.layers.EinsumDense after TF 2.9

PiperOrigin-RevId: 453749903
parent 1738af44
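The change below is mechanical: every call site moves from the experimental export to the stable tf.keras.layers.EinsumDense available since TF 2.9, with constructor arguments left untouched. As a minimal, hypothetical sketch (not part of this commit), downstream code that must run on both pre- and post-2.9 TensorFlow could resolve the symbol once and reuse it, here with the same query-projection equation that appears in the diff below:

import tensorflow as tf

# Hypothetical compatibility shim: prefer the stable export (TF >= 2.9) and
# fall back to the experimental path on older releases. The constructor
# signature is identical in both locations.
try:
  EinsumDense = tf.keras.layers.EinsumDense
except AttributeError:
  EinsumDense = tf.keras.layers.experimental.EinsumDense

# Illustrative values; the real layers derive these from the input shape.
num_heads, size_per_head = 8, 64
query_dense = EinsumDense(
    "BTE,ENH->BTNH",  # [batch, seq, hidden] -> [batch, seq, heads, head_dim]
    output_shape=(None, num_heads, size_per_head),
    bias_axes=None,
    name="query")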
@@ -52,19 +52,19 @@ class Attention(tf.keras.layers.Layer):
attention_initializer = _glorot_initializer(input_shape.as_list()[-1],
self.hidden_size)
-self.query_dense_layer = tf.keras.layers.experimental.EinsumDense(
+self.query_dense_layer = tf.keras.layers.EinsumDense(
"BTE,ENH->BTNH",
output_shape=(None, self.num_heads, size_per_head),
kernel_initializer=tf_utils.clone_initializer(attention_initializer),
bias_axes=None,
name="query")
-self.key_dense_layer = tf.keras.layers.experimental.EinsumDense(
+self.key_dense_layer = tf.keras.layers.EinsumDense(
"BTE,ENH->BTNH",
output_shape=(None, self.num_heads, size_per_head),
kernel_initializer=tf_utils.clone_initializer(attention_initializer),
bias_axes=None,
name="key")
-self.value_dense_layer = tf.keras.layers.experimental.EinsumDense(
+self.value_dense_layer = tf.keras.layers.EinsumDense(
"BTE,ENH->BTNH",
output_shape=(None, self.num_heads, size_per_head),
kernel_initializer=tf_utils.clone_initializer(attention_initializer),
@@ -72,7 +72,7 @@ class Attention(tf.keras.layers.Layer):
name="value")
output_initializer = _glorot_initializer(self.hidden_size, self.hidden_size)
-self.output_dense_layer = tf.keras.layers.experimental.EinsumDense(
+self.output_dense_layer = tf.keras.layers.EinsumDense(
"BTNH,NHE->BTE",
output_shape=(None, self.hidden_size),
kernel_initializer=output_initializer,
......
@@ -18,7 +18,7 @@ import math
import tensorflow as tf
-EinsumDense = tf.keras.layers.experimental.EinsumDense
+EinsumDense = tf.keras.layers.EinsumDense
MultiHeadAttention = tf.keras.layers.MultiHeadAttention
......
@@ -88,7 +88,7 @@ class BlockDiagFeedforward(tf.keras.layers.Layer):
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint)
-self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
+self._intermediate_dense = tf.keras.layers.EinsumDense(
"abc,cde->abde",
output_shape=(None, self._num_blocks,
self._intermediate_size // self._num_blocks),
@@ -106,10 +106,9 @@ class BlockDiagFeedforward(tf.keras.layers.Layer):
self._intermediate_activation_layer = tf.keras.layers.Activation(
self._intermediate_activation, dtype=policy)
-self._output_dense = tf.keras.layers.experimental.EinsumDense(
+self._output_dense = tf.keras.layers.EinsumDense(
"abde,deo->abdo",
-output_shape=(None, self._num_blocks,
-hidden_size // self._num_blocks),
+output_shape=(None, self._num_blocks, hidden_size // self._num_blocks),
bias_axes="do",
name="output",
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
@@ -117,7 +116,7 @@ class BlockDiagFeedforward(tf.keras.layers.Layer):
**common_kwargs)
if self._apply_mixing:
-self._output_mixing = tf.keras.layers.experimental.EinsumDense(
+self._output_mixing = tf.keras.layers.EinsumDense(
"abdo,de->abeo",
output_shape=(None, self._num_blocks,
hidden_size // self._num_blocks),
......
@@ -116,7 +116,7 @@ class GatedFeedforward(tf.keras.layers.Layer):
activation_policy = tf.float32
for i in range(self._num_blocks):
self._intermediate_dense.append(
-tf.keras.layers.experimental.EinsumDense(
+tf.keras.layers.EinsumDense(
"abc,cd->abd",
output_shape=(None, self._intermediate_size),
bias_axes="d",
@@ -131,7 +131,7 @@ class GatedFeedforward(tf.keras.layers.Layer):
self._intermediate_activation, dtype=activation_policy))
if self._use_gate:
self._gate_dense.append(
-tf.keras.layers.experimental.EinsumDense(
+tf.keras.layers.EinsumDense(
"abc,cd->abd",
output_shape=(None, self._intermediate_size),
bias_axes="d",
@@ -142,7 +142,7 @@ class GatedFeedforward(tf.keras.layers.Layer):
self._bias_initializer),
**common_kwargs))
self._output_dense.append(
-tf.keras.layers.experimental.EinsumDense(
+tf.keras.layers.EinsumDense(
"abc,cd->abd",
output_shape=(None, hidden_size),
bias_axes="d",
......
@@ -122,7 +122,7 @@ class MobileBertEmbedding(tf.keras.layers.Layer):
max_length=max_sequence_length,
initializer=tf_utils.clone_initializer(self.initializer),
name='position_embedding')
-self.word_embedding_proj = tf.keras.layers.experimental.EinsumDense(
+self.word_embedding_proj = tf.keras.layers.EinsumDense(
'abc,cd->abd',
output_shape=[None, self.output_embed_size],
kernel_initializer=tf_utils.clone_initializer(self.initializer),
@@ -244,7 +244,7 @@ class MobileBertTransformer(tf.keras.layers.Layer):
self.block_layers = {}
# add input bottleneck
-dense_layer_2d = tf.keras.layers.experimental.EinsumDense(
+dense_layer_2d = tf.keras.layers.EinsumDense(
'abc,cd->abd',
output_shape=[None, self.intra_bottleneck_size],
bias_axes='d',
@@ -256,7 +256,7 @@ class MobileBertTransformer(tf.keras.layers.Layer):
layer_norm]
if self.key_query_shared_bottleneck:
-dense_layer_2d = tf.keras.layers.experimental.EinsumDense(
+dense_layer_2d = tf.keras.layers.EinsumDense(
'abc,cd->abd',
output_shape=[None, self.intra_bottleneck_size],
bias_axes='d',
@@ -286,7 +286,7 @@ class MobileBertTransformer(tf.keras.layers.Layer):
for ffn_layer_idx in range(self.num_feedforward_networks):
layer_prefix = f'ffn_layer_{ffn_layer_idx}'
layer_name = layer_prefix + '/intermediate_dense'
-intermediate_layer = tf.keras.layers.experimental.EinsumDense(
+intermediate_layer = tf.keras.layers.EinsumDense(
'abc,cd->abd',
activation=self.intermediate_act_fn,
output_shape=[None, self.intermediate_size],
@@ -294,7 +294,7 @@ class MobileBertTransformer(tf.keras.layers.Layer):
kernel_initializer=tf_utils.clone_initializer(self.initializer),
name=layer_name)
layer_name = layer_prefix + '/output_dense'
-output_layer = tf.keras.layers.experimental.EinsumDense(
+output_layer = tf.keras.layers.EinsumDense(
'abc,cd->abd',
output_shape=[None, self.intra_bottleneck_size],
bias_axes='d',
@@ -308,7 +308,7 @@ class MobileBertTransformer(tf.keras.layers.Layer):
layer_norm])
# add output bottleneck
-bottleneck = tf.keras.layers.experimental.EinsumDense(
+bottleneck = tf.keras.layers.EinsumDense(
'abc,cd->abd',
output_shape=[None, self.hidden_size],
activation=None,
......
@@ -66,7 +66,7 @@ class VotingAttention(tf.keras.layers.Layer):
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint)
-self._query_dense = tf.keras.layers.experimental.EinsumDense(
+self._query_dense = tf.keras.layers.EinsumDense(
"BAE,ENH->BANH",
output_shape=(None, self._num_heads, self._head_size),
bias_axes="NH",
@@ -74,7 +74,7 @@ class VotingAttention(tf.keras.layers.Layer):
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs)
-self._key_dense = tf.keras.layers.experimental.EinsumDense(
+self._key_dense = tf.keras.layers.EinsumDense(
"BAE,ENH->BANH",
output_shape=(None, self._num_heads, self._head_size),
bias_axes="NH",
......
@@ -98,14 +98,14 @@ class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention):
`[B, L, dim]`.
segment_matrix: Optional `Tensor` representing segmentation IDs used in
XLNet of shape `[B, S, S + M]`.
-segment_encoding: Optional `Tensor` representing the segmentation
-encoding as used in XLNet of shape `[2, num_heads, dim]`.
-segment_attention_bias: Optional trainable bias parameter added to the
-query had when calculating the segment-based attention score used in
-XLNet of shape `[num_heads, dim]`.
+segment_encoding: Optional `Tensor` representing the segmentation encoding
+as used in XLNet of shape `[2, num_heads, dim]`.
+segment_attention_bias: Optional trainable bias parameter added to the query
+had when calculating the segment-based attention score used in XLNet of
+shape `[num_heads, dim]`.
state: Optional `Tensor` of shape `[B, M, E]` where M is the length of the
-state or memory.
-If passed, this is also attended over as in Transformer XL.
+state or memory. If passed, this is also attended over as in Transformer
+XL.
attention_mask: A boolean mask of shape `[B, T, S]` that prevents attention
to certain positions.
"""
@@ -144,7 +144,7 @@ class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention):
with tf.init_scope():
einsum_equation, _, output_rank = _build_proj_equation(
key_shape.rank - 1, bound_dims=1, output_dims=2)
-self._encoding_dense = tf.keras.layers.experimental.EinsumDense(
+self._encoding_dense = tf.keras.layers.EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._key_dim]),
@@ -255,8 +255,8 @@ class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention):
Args:
query: attention input.
value: attention input.
-content_attention_bias: A trainable bias parameter added to the query
-head when calculating the content-based attention score.
+content_attention_bias: A trainable bias parameter added to the query head
+when calculating the content-based attention score.
positional_attention_bias: A trainable bias parameter added to the query
head when calculating the position-based attention score.
key: attention input.
@@ -264,8 +264,8 @@ class MultiHeadRelativeAttention(tf.keras.layers.MultiHeadAttention):
value.
segment_matrix: Optional `Tensor` representing segmentation IDs used in
XLNet.
-segment_encoding: Optional `Tensor` representing the segmentation
-encoding as used in XLNet.
+segment_encoding: Optional `Tensor` representing the segmentation encoding
+as used in XLNet.
segment_attention_bias: Optional trainable bias parameter added to the
query had when calculating the segment-based attention score used in
XLNet.
@@ -394,22 +394,22 @@ class TwoStreamRelativeAttention(MultiHeadRelativeAttention):
content_stream: The content representation, commonly referred to as h.
This serves a similar role to the standard hidden states in
Transformer-XL.
-content_attention_bias: A trainable bias parameter added to the query
-head when calculating the content-based attention score.
+content_attention_bias: A trainable bias parameter added to the query head
+when calculating the content-based attention score.
positional_attention_bias: A trainable bias parameter added to the query
head when calculating the position-based attention score.
-query_stream: The query representation, commonly referred to as g.
-This only has access to contextual information and position, but not
-content. If not provided, then this is MultiHeadRelativeAttention with
+query_stream: The query representation, commonly referred to as g. This
+only has access to contextual information and position, but not content.
+If not provided, then this is MultiHeadRelativeAttention with
self-attention.
relative_position_encoding: relative positional encoding for key and
value.
-target_mapping: Optional `Tensor` representing the target mapping used
-in partial prediction.
+target_mapping: Optional `Tensor` representing the target mapping used in
+partial prediction.
segment_matrix: Optional `Tensor` representing segmentation IDs used in
XLNet.
-segment_encoding: Optional `Tensor` representing the segmentation
-encoding as used in XLNet.
+segment_encoding: Optional `Tensor` representing the segmentation encoding
+as used in XLNet.
segment_attention_bias: Optional trainable bias parameter added to the
query head when calculating the segment-based attention score.
state: (default None) optional state. If passed, this is also attended
@@ -417,8 +417,8 @@ class TwoStreamRelativeAttention(MultiHeadRelativeAttention):
content_attention_mask: (default None) Optional mask that is added to
content attention logits. If state is not None, the mask source sequence
dimension should extend M.
-query_attention_mask: (default None) Optional mask that is added to
-query attention logits. If state is not None, the mask source sequence
+query_attention_mask: (default None) Optional mask that is added to query
+attention logits. If state is not None, the mask source sequence
dimension should extend M.
Returns:
@@ -496,4 +496,3 @@ class TwoStreamRelativeAttention(MultiHeadRelativeAttention):
query_attention_output = self._output_dense(query_attention_output)
return content_attention_output, query_attention_output
@@ -362,36 +362,37 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer):
if self._reuse_heads < self._num_heads:
einsum_equation, bias_axes, output_rank = _build_proj_equation(
free_dims, bound_dims=1, output_dims=2)
-self._query_dense = tf.keras.layers.experimental.EinsumDense(
+self._query_dense = tf.keras.layers.EinsumDense(
einsum_equation,
-output_shape=_get_output_shape(output_rank - 1, [
-self._num_heads - self._reuse_heads, self._key_dim]),
+output_shape=_get_output_shape(
+output_rank - 1,
+[self._num_heads - self._reuse_heads, self._key_dim]),
bias_axes=bias_axes if self._use_bias else None,
name="query",
kernel_initializer=tf_utils.clone_initializer(
self._kernel_initializer),
-bias_initializer=tf_utils.clone_initializer(
-self._bias_initializer),
+bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs)
einsum_equation, bias_axes, output_rank = _build_proj_equation(
self._key_shape.rank - 1, bound_dims=1, output_dims=2)
-self._key_dense = tf.keras.layers.experimental.EinsumDense(
+self._key_dense = tf.keras.layers.EinsumDense(
einsum_equation,
-output_shape=_get_output_shape(output_rank - 1, [
-self._num_heads - self._reuse_heads, self._key_dim]),
+output_shape=_get_output_shape(
+output_rank - 1,
+[self._num_heads - self._reuse_heads, self._key_dim]),
bias_axes=bias_axes if self._use_bias else None,
name="key",
kernel_initializer=tf_utils.clone_initializer(
self._kernel_initializer),
-bias_initializer=tf_utils.clone_initializer(
-self._bias_initializer),
+bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs)
einsum_equation, bias_axes, output_rank = _build_proj_equation(
self._value_shape.rank - 1, bound_dims=1, output_dims=2)
self._value_dense = []
if self._reuse_heads > 0:
-self._value_dense.append(tf.keras.layers.experimental.EinsumDense(
+self._value_dense.append(
+tf.keras.layers.EinsumDense(
einsum_equation,
output_shape=_get_output_shape(
output_rank - 1, [self._reuse_heads, self._value_dim]),
@@ -403,10 +404,12 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer):
self._bias_initializer),
**common_kwargs))
if self._reuse_heads < self._num_heads:
-self._value_dense.append(tf.keras.layers.experimental.EinsumDense(
+self._value_dense.append(
+tf.keras.layers.EinsumDense(
einsum_equation,
-output_shape=_get_output_shape(output_rank - 1, [
-self._num_heads - self._reuse_heads, self._value_dim]),
+output_shape=_get_output_shape(
+output_rank - 1,
+[self._num_heads - self._reuse_heads, self._value_dim]),
bias_axes=bias_axes if self._use_bias else None,
name="value_new",
kernel_initializer=tf_utils.clone_initializer(
@@ -450,15 +453,13 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer):
output_shape = [self._query_shape[-1]]
einsum_equation, bias_axes, output_rank = _build_proj_equation(
free_dims, bound_dims=2, output_dims=len(output_shape))
-return tf.keras.layers.experimental.EinsumDense(
+return tf.keras.layers.EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1, output_shape),
bias_axes=bias_axes if (use_bias and self._use_bias) else None,
name=name,
-kernel_initializer=tf_utils.clone_initializer(
-self._kernel_initializer),
-bias_initializer=tf_utils.clone_initializer(
-self._bias_initializer),
+kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
+bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs)
def _build_attention(self, rank):
......
@@ -187,7 +187,7 @@ class ReuseTransformer(tf.keras.layers.Layer):
axis=-1,
epsilon=self._norm_epsilon,
dtype=tf.float32))
-self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
+self._intermediate_dense = tf.keras.layers.EinsumDense(
einsum_equation,
output_shape=(None, self._inner_dim),
bias_axes="d",
@@ -205,7 +205,7 @@ class ReuseTransformer(tf.keras.layers.Layer):
self._inner_activation, dtype=policy)
self._inner_dropout_layer = tf.keras.layers.Dropout(
rate=self._inner_dropout)
-self._output_dense = tf.keras.layers.experimental.EinsumDense(
+self._output_dense = tf.keras.layers.EinsumDense(
einsum_equation,
output_shape=(None, hidden_size),
bias_axes="d",
......
@@ -145,7 +145,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
axis=-1,
epsilon=1e-12,
dtype=tf.float32))
-self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
+self._intermediate_dense = tf.keras.layers.EinsumDense(
"abc,cd->abd",
output_shape=(None, self._intermediate_size),
bias_axes="d",
@@ -161,7 +161,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
policy = tf.float32
self._intermediate_activation_layer = tf.keras.layers.Activation(
self._intermediate_activation, dtype=policy)
-self._output_dense = tf.keras.layers.experimental.EinsumDense(
+self._output_dense = tf.keras.layers.EinsumDense(
"abc,cd->abd",
output_shape=(None, hidden_size),
bias_axes="d",
......
@@ -262,7 +262,7 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
name="self_attention",
**common_kwargs)
-self.self_attention_output_dense = tf.keras.layers.experimental.EinsumDense(
+self.self_attention_output_dense = tf.keras.layers.EinsumDense(
"abc,cd->abd",
output_shape=(None, hidden_size),
bias_axes="d",
@@ -301,7 +301,7 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
dtype="float32"))
# Feed-forward projection.
-self.intermediate_dense = tf.keras.layers.experimental.EinsumDense(
+self.intermediate_dense = tf.keras.layers.EinsumDense(
"abc,cd->abd",
output_shape=(None, self.intermediate_size),
bias_axes="d",
@@ -313,7 +313,7 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
self.intermediate_activation)
self._intermediate_dropout_layer = tf.keras.layers.Dropout(
rate=self._intermediate_dropout)
-self.output_dense = tf.keras.layers.experimental.EinsumDense(
+self.output_dense = tf.keras.layers.EinsumDense(
"abc,cd->abd",
output_shape=(None, hidden_size),
bias_axes="d",
......
@@ -224,7 +224,7 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
epsilon=self._norm_epsilon,
dtype=tf.float32))
-self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
+self._intermediate_dense = tf.keras.layers.EinsumDense(
einsum_equation,
output_shape=(None, self._inner_dim),
bias_axes="d",
@@ -242,7 +242,7 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
self._inner_activation, dtype=policy)
self._inner_dropout_layer = tf.keras.layers.Dropout(
rate=self._inner_dropout)
-self._output_dense = tf.keras.layers.experimental.EinsumDense(
+self._output_dense = tf.keras.layers.EinsumDense(
einsum_equation,
output_shape=(None, last_output_shape),
bias_axes="d",
......
@@ -190,7 +190,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
dtype=tf.float32))
if self._feedforward_block is None:
-self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
+self._intermediate_dense = tf.keras.layers.EinsumDense(
"abc,cd->abd",
output_shape=(None, self._intermediate_size),
bias_axes="d",
@@ -207,7 +207,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
policy = tf.float32
self._intermediate_activation_layer = tf.keras.layers.Activation(
self._intermediate_activation, dtype=policy)
-self._output_dense = tf.keras.layers.experimental.EinsumDense(
+self._output_dense = tf.keras.layers.EinsumDense(
"abc,cd->abd",
output_shape=(None, hidden_size),
bias_axes="d",
......
@@ -58,7 +58,7 @@ class ValidatedFeedforwardLayer(tf.keras.layers.Layer):
def build(self, input_shape):
hidden_size = input_shape.as_list()[-1]
-self._feedforward_dense = tf.keras.layers.experimental.EinsumDense(
+self._feedforward_dense = tf.keras.layers.EinsumDense(
'...x,xy->...y',
output_shape=hidden_size,
bias_axes='y',
......
@@ -158,7 +158,7 @@ class TransformerXLBlock(tf.keras.layers.Layer):
axis=-1,
epsilon=self._norm_epsilon,
dtype=tf.float32)
-self._inner_dense = tf.keras.layers.experimental.EinsumDense(
+self._inner_dense = tf.keras.layers.EinsumDense(
"abc,cd->abd",
output_shape=(None, self._inner_size),
bias_axes="d",
@@ -169,7 +169,7 @@ class TransformerXLBlock(tf.keras.layers.Layer):
self._inner_activation)
self._inner_dropout_layer = tf.keras.layers.Dropout(
rate=self._inner_dropout)
-self._output_dense = tf.keras.layers.experimental.EinsumDense(
+self._output_dense = tf.keras.layers.EinsumDense(
"abc,cd->abd",
output_shape=(None, hidden_size),
bias_axes="d",
......
@@ -124,7 +124,7 @@ class AlbertEncoder(tf.keras.Model):
# We project the 'embedding' output to 'hidden_size' if it is not already
# 'hidden_size'.
if embedding_width != hidden_size:
-embeddings = tf.keras.layers.experimental.EinsumDense(
+embeddings = tf.keras.layers.EinsumDense(
'...x,xy->...y',
output_shape=hidden_size,
bias_axes='y',
......
@@ -150,7 +150,7 @@ class BertEncoderV2(tf.keras.layers.Layer):
# 'hidden_size'.
self._embedding_projection = None
if embedding_width != hidden_size:
-self._embedding_projection = tf.keras.layers.experimental.EinsumDense(
+self._embedding_projection = tf.keras.layers.EinsumDense(
'...x,xy->...y',
output_shape=hidden_size,
bias_axes='y',
@@ -442,7 +442,7 @@ class BertEncoder(tf.keras.Model):
# We project the 'embedding' output to 'hidden_size' if it is not already
# 'hidden_size'.
if embedding_width != hidden_size:
-embedding_projection = tf.keras.layers.experimental.EinsumDense(
+embedding_projection = tf.keras.layers.EinsumDense(
'...x,xy->...y',
output_shape=hidden_size,
bias_axes='y',
......
@@ -293,7 +293,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
# 'hidden_size'.
self._embedding_projection = None
if embedding_width != hidden_size:
-self._embedding_projection = tf.keras.layers.experimental.EinsumDense(
+self._embedding_projection = tf.keras.layers.EinsumDense(
'...x,xy->...y',
output_shape=hidden_size,
bias_axes='y',
......
@@ -146,7 +146,7 @@ class MobileBERTEncoder(tf.keras.Model):
first_token = tf.squeeze(prev_output[:, 0:1, :], axis=1)
if classifier_activation:
-self._pooler_layer = tf.keras.layers.experimental.EinsumDense(
+self._pooler_layer = tf.keras.layers.EinsumDense(
'ab,bc->ac',
output_shape=hidden_size,
activation=tf.tanh,
......
@@ -128,7 +128,7 @@ class PackedSequenceEmbedding(tf.keras.Model):
embeddings)
if embedding_width != hidden_size:
-embeddings = tf.keras.layers.experimental.EinsumDense(
+embeddings = tf.keras.layers.EinsumDense(
'...x,xy->...y',
output_shape=hidden_size,
bias_axes=None,
......