"git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "3ab521fa2b26b6dbfdb87d36739f0e4191a500a1"
Commit 5fc4c351 authored by Hongkun Yu, committed by A. Unique TensorFlower

Consolidate the self-attention mask layer.

Ask users to use keras_nlp.

PiperOrigin-RevId: 346621226
parent 74438979
@@ -15,13 +15,15 @@
 """Keras layer that creates a self-attention mask."""
 import tensorflow as tf
-from official.modeling import tf_utils
+from official.nlp.keras_nlp import layers
 @tf.keras.utils.register_keras_serializable(package='Text')
-class SelfAttentionMask(tf.keras.layers.Layer):
+class SelfAttentionMask(layers.SelfAttentionMask):
   """Create 3D attention mask from a 2D tensor mask.
+  **Warning: Please use the keras_nlp.layers.SelfAttentionMask.**
   inputs[0]: from_tensor: 2D or 3D Tensor of shape
     [batch_size, from_seq_length, ...].
   inputs[1]: to_mask: int32 Tensor of shape [batch_size, to_seq_length].
@@ -31,28 +33,7 @@ class SelfAttentionMask(tf.keras.layers.Layer):
   """
   def call(self, inputs):
-    from_tensor = inputs[0]
-    to_mask = inputs[1]
-    from_shape = tf_utils.get_shape_list(from_tensor, expected_rank=[2, 3])
-    batch_size = from_shape[0]
-    from_seq_length = from_shape[1]
-    to_shape = tf_utils.get_shape_list(to_mask, expected_rank=2)
-    to_seq_length = to_shape[1]
-    to_mask = tf.cast(
-        tf.reshape(to_mask, [batch_size, 1, to_seq_length]),
-        dtype=from_tensor.dtype)
-    # We don't assume that `from_tensor` is a mask (although it could be). We
-    # don't actually care if we attend *from* padding tokens (only *to* padding)
-    # tokens so we create a tensor of all ones.
-    #
-    # `broadcast_ones` = [batch_size, from_seq_length, 1]
-    broadcast_ones = tf.ones(
-        shape=[batch_size, from_seq_length, 1], dtype=from_tensor.dtype)
-    # Here we broadcast along two dimensions to create the mask.
-    mask = broadcast_ones * to_mask
-    return mask
+    if isinstance(inputs, list):
+      return super().call(inputs[0], inputs[1])
+    else:
+      return super().call(inputs)
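For reference, the mask math that the removed call() implemented, and that the keras_nlp layer now provides, is a broadcast of a from-side column of ones against the to-side padding mask. A minimal standalone sketch of that computation; the helper name and the toy shapes below are illustrative only, not part of this change:

import tensorflow as tf

def build_self_attention_mask(from_tensor, to_mask):
  # Dynamic shapes: works for 2D or 3D `from_tensor`.
  batch_size = tf.shape(from_tensor)[0]
  from_seq_length = tf.shape(from_tensor)[1]
  to_seq_length = tf.shape(to_mask)[1]
  # Reshape the 2D padding mask to [batch_size, 1, to_seq_length] and cast it
  # to the activation dtype.
  to_mask = tf.cast(
      tf.reshape(to_mask, [batch_size, 1, to_seq_length]),
      dtype=from_tensor.dtype)
  # Every position may attend *from* (all ones); only the *to* side is
  # restricted, so broadcasting yields
  # [batch_size, from_seq_length, to_seq_length].
  broadcast_ones = tf.ones(
      [batch_size, from_seq_length, 1], dtype=from_tensor.dtype)
  return broadcast_ones * to_mask

# Example: batch of 2, sequence length 4, with trailing padding.
embeddings = tf.zeros([2, 4, 8], dtype=tf.float32)
padding_mask = tf.constant([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=tf.int32)
mask_3d = build_self_attention_mask(embeddings, padding_mask)  # shape (2, 4, 4)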
@@ -133,7 +133,7 @@ class AlbertEncoder(tf.keras.Model):
               embeddings)
     data = embeddings
-    attention_mask = layers.SelfAttentionMask()([data, mask])
+    attention_mask = keras_nlp.layers.SelfAttentionMask()(data, mask)
     shared_layer = keras_nlp.layers.TransformerEncoderBlock(
         num_attention_heads=num_attention_heads,
         inner_dim=intermediate_size,
@@ -169,7 +169,7 @@ class EncoderScaffold(tf.keras.Model):
         tf.keras.layers.Dropout(
             rate=embedding_cfg['dropout_rate'])(embeddings))
-    attention_mask = layers.SelfAttentionMask()([embeddings, mask])
+    attention_mask = keras_nlp.layers.SelfAttentionMask()(embeddings, mask)
     data = embeddings
@@ -16,6 +16,7 @@
 import gin
 import tensorflow as tf
+from official.nlp import keras_nlp
 from official.nlp.modeling import layers
@@ -118,7 +119,7 @@ class MobileBERTEncoder(tf.keras.Model):
     type_ids = tf.keras.layers.Input(
         shape=(None,), dtype=tf.int32, name='input_type_ids')
     self.inputs = [input_ids, input_mask, type_ids]
-    attention_mask = layers.SelfAttentionMask()([input_ids, input_mask])
+    attention_mask = keras_nlp.layers.SelfAttentionMask()(input_ids, input_mask)
     # build the computation graph
     all_layer_outputs = []
@@ -19,6 +19,7 @@ import collections
 import tensorflow as tf
 from official.modeling import tf_utils
+from official.nlp import keras_nlp
 from official.nlp.modeling import layers
@@ -139,7 +140,7 @@ class PackedSequenceEmbedding(tf.keras.Model):
             name='embedding_projection')(
                 embeddings)
-    attention_mask = layers.SelfAttentionMask()([embeddings, mask])
+    attention_mask = keras_nlp.layers.SelfAttentionMask()(embeddings, mask)
     if sub_seq_mask is not None:
       attention_mask = tf.keras.layers.Lambda(
           lambda x: x[0] * tf.cast(x[1], x[0].dtype))(
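All call sites above move from the list-style call on official.nlp.modeling.layers.SelfAttentionMask to the two-argument call on keras_nlp.layers.SelfAttentionMask; the deprecated wrapper keeps accepting the list form and forwards to the keras_nlp implementation. A usage sketch, assuming the Model Garden package layout at this revision (the vocabulary size and embedding width are placeholder values):

import tensorflow as tf
from official.nlp import keras_nlp
from official.nlp.modeling import layers

input_ids = tf.keras.layers.Input(shape=(None,), dtype=tf.int32)
input_mask = tf.keras.layers.Input(shape=(None,), dtype=tf.int32)
embeddings = tf.keras.layers.Embedding(30522, 64)(input_ids)

# Preferred after this change: pass the activations and the 2D padding mask
# as two positional arguments.
attention_mask = keras_nlp.layers.SelfAttentionMask()(embeddings, input_mask)

# Still supported: the deprecated wrapper accepts the old single-list form
# and dispatches to the keras_nlp layer internally.
legacy_attention_mask = layers.SelfAttentionMask()([embeddings, input_mask])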