Commit 5fc4c351 authored by Hongkun Yu, committed by A. Unique TensorFlower

Consolidate the self attention mask layer.

Ask users to use keras_nlp.

PiperOrigin-RevId: 346621226
parent 74438979
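
Note: this change makes the modeling-layer SelfAttentionMask a thin subclass of keras_nlp.layers.SelfAttentionMask and moves its call sites from the old single-list argument to two positional arguments. A minimal usage sketch of the new convention, assuming only the import path used in this diff (shapes are illustrative, not taken from the commit):

    import tensorflow as tf
    from official.nlp import keras_nlp

    # Illustrative shapes: batch of 2, sequence length 4, hidden width 8.
    embeddings = tf.random.uniform([2, 4, 8])          # [batch, from_seq_length, width]
    mask = tf.constant([[1, 1, 1, 0], [1, 1, 0, 0]])   # [batch, to_seq_length], 0 = padding

    # Preferred call after this commit: pass the two tensors as separate arguments.
    attention_mask = keras_nlp.layers.SelfAttentionMask()(embeddings, mask)
    print(attention_mask.shape)  # (2, 4, 4) -> [batch, from_seq_length, to_seq_length]
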
@@ -15,13 +15,15 @@
 """Keras layer that creates a self-attention mask."""

 import tensorflow as tf

-from official.modeling import tf_utils
+from official.nlp.keras_nlp import layers


 @tf.keras.utils.register_keras_serializable(package='Text')
-class SelfAttentionMask(tf.keras.layers.Layer):
+class SelfAttentionMask(layers.SelfAttentionMask):
   """Create 3D attention mask from a 2D tensor mask.

+  **Warning: Please use the keras_nlp.layers.SelfAttentionMask.**
+
   inputs[0]: from_tensor: 2D or 3D Tensor of shape
     [batch_size, from_seq_length, ...].
   inputs[1]: to_mask: int32 Tensor of shape [batch_size, to_seq_length].
@@ -31,28 +33,7 @@ class SelfAttentionMask(tf.keras.layers.Layer):
   """

   def call(self, inputs):
-    from_tensor = inputs[0]
-    to_mask = inputs[1]
-    from_shape = tf_utils.get_shape_list(from_tensor, expected_rank=[2, 3])
-    batch_size = from_shape[0]
-    from_seq_length = from_shape[1]
-
-    to_shape = tf_utils.get_shape_list(to_mask, expected_rank=2)
-    to_seq_length = to_shape[1]
-
-    to_mask = tf.cast(
-        tf.reshape(to_mask, [batch_size, 1, to_seq_length]),
-        dtype=from_tensor.dtype)
-
-    # We don't assume that `from_tensor` is a mask (although it could be). We
-    # don't actually care if we attend *from* padding tokens (only *to* padding)
-    # tokens so we create a tensor of all ones.
-    #
-    # `broadcast_ones` = [batch_size, from_seq_length, 1]
-    broadcast_ones = tf.ones(
-        shape=[batch_size, from_seq_length, 1], dtype=from_tensor.dtype)
-
-    # Here we broadcast along two dimensions to create the mask.
-    mask = broadcast_ones * to_mask
-
-    return mask
+    if isinstance(inputs, list):
+      return super().call(inputs[0], inputs[1])
+    else:
+      return super().call(inputs)
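
The removed body (now handled by the keras_nlp layer) broadcast a [batch, from_seq_length, 1] tensor of ones against the mask reshaped to [batch, 1, to_seq_length]; the remaining shim only keeps the legacy list-style call working. A small sketch of both call styles, assuming the wrapper keeps its official.nlp.modeling.layers import path (tensor values are illustrative):

    import tensorflow as tf
    from official.nlp import keras_nlp            # new home of the layer
    from official.nlp.modeling import layers      # deprecated wrapper from this file

    from_tensor = tf.random.uniform([1, 3, 4])    # [batch, from_seq_length, width]
    to_mask = tf.constant([[1, 1, 0]])            # [batch, to_seq_length], 0 = padding

    # Legacy list-style call, routed through the isinstance(inputs, list) branch above.
    old_style = layers.SelfAttentionMask()([from_tensor, to_mask])

    # Call style the commit migrates the encoders to: two positional arguments.
    new_style = keras_nlp.layers.SelfAttentionMask()(from_tensor, to_mask)

    # Both yield a [batch, from_seq_length, to_seq_length] mask, here (1, 3, 3),
    # equal to ones([1, 3, 1]) * cast(reshape(to_mask, [1, 1, 3]), float32).
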
@@ -133,7 +133,7 @@ class AlbertEncoder(tf.keras.Model):
           embeddings)

     data = embeddings
-    attention_mask = layers.SelfAttentionMask()([data, mask])
+    attention_mask = keras_nlp.layers.SelfAttentionMask()(data, mask)
     shared_layer = keras_nlp.layers.TransformerEncoderBlock(
         num_attention_heads=num_attention_heads,
         inner_dim=intermediate_size,
...
@@ -169,7 +169,7 @@ class EncoderScaffold(tf.keras.Model):
           tf.keras.layers.Dropout(
               rate=embedding_cfg['dropout_rate'])(embeddings))

-      attention_mask = layers.SelfAttentionMask()([embeddings, mask])
+      attention_mask = keras_nlp.layers.SelfAttentionMask()(embeddings, mask)

     data = embeddings
...
@@ -16,6 +16,7 @@
 import gin
 import tensorflow as tf

+from official.nlp import keras_nlp
 from official.nlp.modeling import layers
@@ -118,7 +119,7 @@ class MobileBERTEncoder(tf.keras.Model):
     type_ids = tf.keras.layers.Input(
         shape=(None,), dtype=tf.int32, name='input_type_ids')
     self.inputs = [input_ids, input_mask, type_ids]
-    attention_mask = layers.SelfAttentionMask()([input_ids, input_mask])
+    attention_mask = keras_nlp.layers.SelfAttentionMask()(input_ids, input_mask)

     # build the computation graph
     all_layer_outputs = []
...
@@ -19,6 +19,7 @@ import collections
 import tensorflow as tf

 from official.modeling import tf_utils
+from official.nlp import keras_nlp
 from official.nlp.modeling import layers
@@ -139,7 +140,7 @@ class PackedSequenceEmbedding(tf.keras.Model):
             name='embedding_projection')(
                 embeddings)

-      attention_mask = layers.SelfAttentionMask()([embeddings, mask])
+      attention_mask = keras_nlp.layers.SelfAttentionMask()(embeddings, mask)
       if sub_seq_mask is not None:
         attention_mask = tf.keras.layers.Lambda(
             lambda x: x[0] * tf.cast(x[1], x[0].dtype))(
...