"vscode:/vscode.git/clone" did not exist on "b6f8a268f527d3b3da92b75163d79641a7a2e189"
Commit c1ac2bfc authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 281086898
parent 5a3b762c
@@ -89,7 +89,7 @@ def get_padding(x, padding_value=0, dtype=tf.float32):
   Args:
     x: int tensor with any shape
-    padding_value: int value that
+    padding_value: int which represents padded values in input
     dtype: The dtype of the return value.
 
   Returns:
@@ -100,7 +100,7 @@ def get_padding(x, padding_value=0, dtype=tf.float32):
   return tf.cast(tf.equal(x, padding_value), dtype)
 
 
-def get_padding_bias(x):
+def get_padding_bias(x, padding_value=0, dtype=tf.float32):
   """Calculate bias tensor from padding values in tensor.
 
   Bias tensor that is added to the pre-softmax multi-headed attention logits,
@@ -109,12 +109,14 @@ def get_padding_bias(x):
   Args:
     x: int tensor with shape [batch_size, length]
+    padding_value: int which represents padded values in input
+    dtype: The dtype of the return value
 
   Returns:
     Attention bias tensor of shape [batch_size, 1, 1, length].
   """
   with tf.name_scope("attention_bias"):
-    padding = get_padding(x)
+    padding = get_padding(x, padding_value, dtype)
     attention_bias = padding * _NEG_INF_FP32
     attention_bias = tf.expand_dims(
         tf.expand_dims(attention_bias, axis=1), axis=1)
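
For context, this change threads padding_value and dtype from get_padding_bias through to get_padding, while keeping the previous defaults. Below is a minimal, self-contained sketch of the updated functions together with a hypothetical call; the _NEG_INF_FP32 value and the example inputs are assumptions for illustration and are not part of this commit.

import tensorflow as tf

# Assumed value; the constant is defined elsewhere in this file.
_NEG_INF_FP32 = -1e9


def get_padding(x, padding_value=0, dtype=tf.float32):
  # 1.0 at padded positions, 0.0 elsewhere, cast to the requested dtype.
  return tf.cast(tf.equal(x, padding_value), dtype)


def get_padding_bias(x, padding_value=0, dtype=tf.float32):
  # Large negative bias at padded positions, shaped [batch_size, 1, 1, length]
  # so it broadcasts over attention heads and query positions.
  with tf.name_scope("attention_bias"):
    padding = get_padding(x, padding_value, dtype)
    attention_bias = padding * _NEG_INF_FP32
    attention_bias = tf.expand_dims(
        tf.expand_dims(attention_bias, axis=1), axis=1)
  return attention_bias


# Hypothetical inputs: a batch of 2 token-id sequences padded with 0.
ids = tf.constant([[7, 5, 0, 0],
                   [3, 0, 0, 0]])
bias = get_padding_bias(ids)                        # same behaviour as before
bias_alt = get_padding_bias(ids, padding_value=-1)  # new: custom padding id
print(bias.shape)  # (2, 1, 1, 4)

Because the new parameters default to the old behaviour (padding_value=0, dtype=tf.float32), existing call sites are unaffected, while callers can now build the bias for inputs where 0 is a real token id or request a different output dtype.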