Commit 37c9f3d3 authored by Hongkun Yu, committed by A. Unique TensorFlower

Consolidate the self attention mask layer.

Ask users to use keras_nlp.layers.SelfAttentionMask instead.

PiperOrigin-RevId: 346621226
parent 36db2450
@@ -15,13 +15,15 @@
"""Keras layer that creates a self-attention mask."""
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.keras_nlp import layers
@tf.keras.utils.register_keras_serializable(package='Text')
class SelfAttentionMask(tf.keras.layers.Layer):
class SelfAttentionMask(layers.SelfAttentionMask):
"""Create 3D attention mask from a 2D tensor mask.
**Warning: Please use the keras_nlp.layers.SelfAttentionMask.**
inputs[0]: from_tensor: 2D or 3D Tensor of shape
[batch_size, from_seq_length, ...].
inputs[1]: to_mask: int32 Tensor of shape [batch_size, to_seq_length].
@@ -31,28 +33,7 @@ class SelfAttentionMask(tf.keras.layers.Layer):
"""
def call(self, inputs):
from_tensor = inputs[0]
to_mask = inputs[1]
from_shape = tf_utils.get_shape_list(from_tensor, expected_rank=[2, 3])
batch_size = from_shape[0]
from_seq_length = from_shape[1]
to_shape = tf_utils.get_shape_list(to_mask, expected_rank=2)
to_seq_length = to_shape[1]
to_mask = tf.cast(
tf.reshape(to_mask, [batch_size, 1, to_seq_length]),
dtype=from_tensor.dtype)
# We don't assume that `from_tensor` is a mask (although it could be). We
# don't actually care if we attend *from* padding tokens (only *to* padding)
# tokens so we create a tensor of all ones.
#
# `broadcast_ones` = [batch_size, from_seq_length, 1]
broadcast_ones = tf.ones(
shape=[batch_size, from_seq_length, 1], dtype=from_tensor.dtype)
# Here we broadcast along two dimensions to create the mask.
mask = broadcast_ones * to_mask
return mask
if isinstance(inputs, list):
return super().call(inputs[0], inputs[1])
else:
return super().call(inputs)
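Note (not part of the commit): the compatibility subclass above accepts either the legacy single-list argument or the two-tensor call signature used by keras_nlp. The sketch below is illustrative only, assuming the import paths shown in the hunks that follow; both call styles produce a [batch_size, from_seq_length, to_seq_length] mask whose dtype follows `from_tensor`.

```python
# Illustrative sketch; mirrors the call-site changes made in this commit.
import tensorflow as tf

from official.nlp import keras_nlp
from official.nlp.modeling import layers  # deprecated wrapper shown above

from_tensor = tf.zeros([2, 4, 8], dtype=tf.float32)           # [batch, from_seq, hidden]
to_mask = tf.constant([[1, 1, 1, 0], [1, 1, 0, 0]], tf.int32)  # [batch, to_seq]

# Preferred: the keras_nlp layer takes two positional tensors.
mask_new = keras_nlp.layers.SelfAttentionMask()(from_tensor, to_mask)

# Legacy: the wrapper still accepts the old two-element list.
mask_old = layers.SelfAttentionMask()([from_tensor, to_mask])

print(mask_new.shape, mask_old.shape)  # (2, 4, 4) (2, 4, 4)
```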
@@ -133,7 +133,7 @@ class AlbertEncoder(tf.keras.Model):
            embeddings)
    data = embeddings
    attention_mask = layers.SelfAttentionMask()([data, mask])
    attention_mask = keras_nlp.layers.SelfAttentionMask()(data, mask)
    shared_layer = keras_nlp.layers.TransformerEncoderBlock(
        num_attention_heads=num_attention_heads,
        inner_dim=intermediate_size,
@@ -169,7 +169,7 @@ class EncoderScaffold(tf.keras.Model):
          tf.keras.layers.Dropout(
              rate=embedding_cfg['dropout_rate'])(embeddings))

      attention_mask = layers.SelfAttentionMask()([embeddings, mask])
      attention_mask = keras_nlp.layers.SelfAttentionMask()(embeddings, mask)
      data = embeddings
@@ -16,6 +16,7 @@
import gin
import tensorflow as tf
from official.nlp import keras_nlp
from official.nlp.modeling import layers
@@ -118,7 +119,7 @@ class MobileBERTEncoder(tf.keras.Model):
    type_ids = tf.keras.layers.Input(
        shape=(None,), dtype=tf.int32, name='input_type_ids')
    self.inputs = [input_ids, input_mask, type_ids]
    attention_mask = layers.SelfAttentionMask()([input_ids, input_mask])
    attention_mask = keras_nlp.layers.SelfAttentionMask()(input_ids, input_mask)
    # build the computation graph
    all_layer_outputs = []
@@ -19,6 +19,7 @@ import collections
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp import keras_nlp
from official.nlp.modeling import layers
@@ -139,7 +140,7 @@ class PackedSequenceEmbedding(tf.keras.Model):
        name='embedding_projection')(
            embeddings)
    attention_mask = layers.SelfAttentionMask()([embeddings, mask])
    attention_mask = keras_nlp.layers.SelfAttentionMask()(embeddings, mask)
    if sub_seq_mask is not None:
      attention_mask = tf.keras.layers.Lambda(
          lambda x: x[0] * tf.cast(x[1], x[0].dtype))(
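For the packed-sequence path above, the mask built by keras_nlp.layers.SelfAttentionMask is still combined with the packing mask by elementwise multiplication, as the Lambda layer does. Below is a hedged sketch of that combination; the block-diagonal `sub_seq_mask` is a made-up example, since its real construction lies outside this hunk.

```python
# Illustrative sketch only; the real sub_seq_mask comes from the packing logic,
# which is not shown in this hunk.
import tensorflow as tf

from official.nlp import keras_nlp

embeddings = tf.zeros([1, 4, 8], dtype=tf.float32)  # [batch, seq, hidden]
mask = tf.constant([[1, 1, 1, 1]], tf.int32)         # [batch, seq]

attention_mask = keras_nlp.layers.SelfAttentionMask()(embeddings, mask)

# Hypothetical packing mask: positions 0-1 and 2-3 belong to separate
# sub-sequences, so cross-sub-sequence attention is masked out.
sub_seq_mask = tf.constant([[[1, 1, 0, 0],
                             [1, 1, 0, 0],
                             [0, 0, 1, 1],
                             [0, 0, 1, 1]]], tf.int32)

# Same elementwise combination as the Lambda layer above.
attention_mask = attention_mask * tf.cast(sub_seq_mask, attention_mask.dtype)
```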