"examples/vscode:/vscode.git/clone" did not exist on "df2e145e5f7fbd79979e883ee398b840a74694b3"
Commit 28f70bc4 authored by Zhenyu Tan, committed by A. Unique TensorFlower

Move SelfAttentionMask to keras_nlp

PiperOrigin-RevId: 330397061
parent c60951b1
@@ -14,4 +14,5 @@
 # ==============================================================================
 """Keras-NLP layers package definition."""
 from official.nlp.keras_nlp.layers.position_embedding import PositionEmbedding
+from official.nlp.keras_nlp.layers.self_attention_mask import SelfAttentionMask
 from official.nlp.keras_nlp.layers.transformer_encoder_block import TransformerEncoderBlock
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras layer that creates a self-attention mask."""
import tensorflow as tf
@tf.keras.utils.register_keras_serializable(package='keras_nlp')
class SelfAttentionMask(tf.keras.layers.Layer):
"""Create 3D attention mask from a 2D tensor mask.
inputs[0]: from_tensor: 2D or 3D Tensor of shape
[batch_size, from_seq_length, ...].
inputs[1]: to_mask: int32 Tensor of shape [batch_size, to_seq_length].
Returns:
float Tensor of shape [batch_size, from_seq_length, to_seq_length].
"""
def call(self, inputs, to_mask):
from_shape = tf.shape(inputs)
batch_size = from_shape[0]
from_seq_length = from_shape[1]
to_shape = tf.shape(to_mask)
to_seq_length = to_shape[1]
to_mask = tf.cast(
tf.reshape(to_mask, [batch_size, 1, to_seq_length]),
dtype=inputs.dtype)
# We don't assume that `from_tensor` is a mask (although it could be). We
# don't actually care if we attend *from* padding tokens (only *to* padding)
# tokens so we create a tensor of all ones.
#
# `broadcast_ones` = [batch_size, from_seq_length, 1]
broadcast_ones = tf.ones(
shape=[batch_size, from_seq_length, 1], dtype=inputs.dtype)
# Here we broadcast along two dimensions to create the mask.
mask = broadcast_ones * to_mask
return mask
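As a quick illustration of the broadcasting the layer performs, the sketch below builds the 3D attention mask for a toy batch. It is not part of the commit; the tensor shapes and values are arbitrary examples, and the import assumes the `official.nlp.keras_nlp` package (with the new export added above) is on the Python path.

```python
import tensorflow as tf

from official.nlp.keras_nlp.layers import SelfAttentionMask

# A batch of 2 sequences of length 4; the second sequence ends with one
# padding token (mask value 0).
embeddings = tf.random.uniform([2, 4, 8])             # [batch, from_seq, width]
padding_mask = tf.constant([[1, 1, 1, 1],
                            [1, 1, 1, 0]], tf.int32)  # [batch, to_seq]

# Produces a [batch, from_seq, to_seq] float mask: every row of the second
# example has a 0 in its last column, so no position attends *to* the padding.
attention_mask = SelfAttentionMask()(embeddings, padding_mask)
print(attention_mask.shape)  # (2, 4, 4)
```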