Commit c1ac2bfc authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 281086898
parent 5a3b762c
@@ -89,7 +89,7 @@ def get_padding(x, padding_value=0, dtype=tf.float32):
 
   Args:
     x: int tensor with any shape
-    padding_value: int value that
+    padding_value: int which represents padded values in input
     dtype: The dtype of the return value.
 
   Returns:
@@ -100,7 +100,7 @@ def get_padding(x, padding_value=0, dtype=tf.float32):
     return tf.cast(tf.equal(x, padding_value), dtype)
 
 
-def get_padding_bias(x):
+def get_padding_bias(x, padding_value=0, dtype=tf.float32):
   """Calculate bias tensor from padding values in tensor.
 
   Bias tensor that is added to the pre-softmax multi-headed attention logits,
@@ -109,12 +109,14 @@ def get_padding_bias(x):
 
   Args:
     x: int tensor with shape [batch_size, length]
+    padding_value: int which represents padded values in input
+    dtype: The dtype of the return value
 
   Returns:
     Attention bias tensor of shape [batch_size, 1, 1, length].
   """
   with tf.name_scope("attention_bias"):
-    padding = get_padding(x)
+    padding = get_padding(x, padding_value, dtype)
     attention_bias = padding * _NEG_INF_FP32
     attention_bias = tf.expand_dims(
         tf.expand_dims(attention_bias, axis=1), axis=1)
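A minimal sketch (not part of the commit) of how the updated signatures behave once padding_value and dtype are threaded through. It reproduces get_padding and get_padding_bias from the diff above, assumes _NEG_INF_FP32 is -1e9 as the docstring describes, and omits the tf.name_scope wrappers for brevity:

import tensorflow as tf

# Assumed from the surrounding file: the docstring above describes the bias
# as -1e9 ("negative infinity") at padding locations.
_NEG_INF_FP32 = -1e9


def get_padding(x, padding_value=0, dtype=tf.float32):
  """Returns a tensor that is 1 at padding positions and 0 elsewhere."""
  return tf.cast(tf.equal(x, padding_value), dtype)


def get_padding_bias(x, padding_value=0, dtype=tf.float32):
  """Builds the [batch_size, 1, 1, length] pre-softmax attention bias."""
  padding = get_padding(x, padding_value, dtype)
  attention_bias = padding * _NEG_INF_FP32
  return tf.expand_dims(tf.expand_dims(attention_bias, axis=1), axis=1)


# A batch whose sequences are padded with -1 rather than the default 0:
ids = tf.constant([[5, 3, -1, -1]])
bias = get_padding_bias(ids, padding_value=-1)
print(bias.shape)    # (1, 1, 1, 4)
print(bias.numpy())  # [[[[0., 0., -1e9, -1e9]]]], broadcast over heads/queries

Before this change, get_padding_bias hard-coded padding_value=0 and tf.float32; the new arguments let callers mark padding with any token id and choose the dtype of the returned mask and bias.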