Commit 46a12d4e authored by Scott Zhu, committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 450124488
parent 97e6a524
......@@ -44,15 +44,5 @@ class SelfAttentionMask(tf.keras.layers.Layer):
tf.reshape(to_mask, [batch_size, 1, to_seq_length]),
dtype=inputs.dtype)
# We don't assume that `from_tensor` is a mask (although it could be). We
# don't actually care if we attend *from* padding tokens (only *to* padding)
# tokens so we create a tensor of all ones.
#
# `broadcast_ones` = [batch_size, from_seq_length, 1]
broadcast_ones = tf.ones(
shape=[batch_size, from_seq_length, 1], dtype=inputs.dtype)
# Here we broadcast along two dimensions to create the mask.
mask = broadcast_ones * to_mask
return mask
return tf.broadcast_to(to_mask,
[batch_size, from_seq_length, to_seq_length])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment