Commit 46a12d4e authored by Scott Zhu's avatar Scott Zhu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 450124488
parent 97e6a524
...@@ -44,15 +44,5 @@ class SelfAttentionMask(tf.keras.layers.Layer): ...@@ -44,15 +44,5 @@ class SelfAttentionMask(tf.keras.layers.Layer):
tf.reshape(to_mask, [batch_size, 1, to_seq_length]), tf.reshape(to_mask, [batch_size, 1, to_seq_length]),
dtype=inputs.dtype) dtype=inputs.dtype)
# We don't assume that `from_tensor` is a mask (although it could be). We return tf.broadcast_to(to_mask,
# don't actually care if we attend *from* padding tokens (only *to* padding) [batch_size, from_seq_length, to_seq_length])
# tokens so we create a tensor of all ones.
#
# `broadcast_ones` = [batch_size, from_seq_length, 1]
broadcast_ones = tf.ones(
shape=[batch_size, from_seq_length, 1], dtype=inputs.dtype)
# Here we broadcast along two dimensions to create the mask.
mask = broadcast_ones * to_mask
return mask
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment